1package ssh
2
3import (
4	"bufio"
5	"bytes"
6	"errors"
7	"fmt"
8	"io"
9	"io/ioutil"
10	"net"
11	"os"
12	"path/filepath"
13
14	log "github.com/hashicorp/go-hclog"
15
16	"golang.org/x/crypto/ssh"
17	"golang.org/x/crypto/ssh/agent"
18)
19
20type comm struct {
21	client  *ssh.Client
22	config  *SSHCommConfig
23	conn    net.Conn
24	address string
25}
26
// SSHCommConfig is the structure used to configure the SSH communicator.
// A pointer to it is shared with the communicator for its whole lifetime.
type SSHCommConfig struct {
	// The configuration of the Go SSH connection, used for every
	// handshake (the initial connect and each reconnect).
	SSHConfig *ssh.ClientConfig

	// Connection returns a new connection. The current connection
	// in use will be closed as part of the Close method, or in the
	// case an error occurs.
	//
	// It is called for the initial connection and again on every
	// reconnect, so it must be set; a non-nil error aborts that attempt.
	Connection func() (net.Conn, error)

	// Pty, if true, will request a pty from the remote end.
	// NOTE(review): nothing in this file reads Pty — confirm it is honored elsewhere.
	Pty bool

	// DisableAgent, if true, will not forward the SSH agent.
	DisableAgent bool

	// Logger for output. The communicator calls it without a nil check
	// on error paths, so it must be set.
	Logger log.Logger
}
46
// SSHCommNew creates a communicator for the SSH server at address and
// connects it immediately: it dials through config.Connection and performs
// the SSH handshake, so a nil error means the communicator is connected and
// ready to open sessions.
//
// address is passed to ssh.NewClientConn as the server address; the network
// connection itself always comes from config.Connection. A nil config or a
// nil config.Connection is reported as an error rather than a panic.
func SSHCommNew(address string, config *SSHCommConfig) (*comm, error) {
	if config == nil {
		return nil, errors.New("ssh communicator config is nil")
	}
	if config.Connection == nil {
		return nil, errors.New("ssh communicator config has no Connection function")
	}

	c := &comm{
		config:  config,
		address: address,
	}

	// Establish the initial connection and SSH handshake.
	if err := c.reconnect(); err != nil {
		return nil, err
	}
	return c, nil
}
63
64func (c *comm) Close() error {
65	var err error
66	if c.conn != nil {
67		err = c.conn.Close()
68	}
69	c.conn = nil
70	c.client = nil
71	return err
72}
73
// Upload copies the contents of input to path on the remote host over SCP.
// When fi describes a regular file, its permission bits and size are sent;
// otherwise input is buffered locally first so its size can be announced.
func (c *comm) Upload(path string, input io.Reader, fi *os.FileInfo) error {
	// filepath.Dir produces backslash-separated paths on Windows (e.g.
	// "\tmp"), which a Unix target cannot use. Forward slashes work for both
	// Unix and Windows targets, so normalise to them.
	targetDir := filepath.ToSlash(filepath.Dir(path))
	targetFile := filepath.Base(path)

	return c.scpSession("scp -vt "+targetDir, func(w io.Writer, stdoutR *bufio.Reader) error {
		return scpUploadFile(targetFile, input, w, stdoutR, fi)
	})
}
90
// NewSession opens a session on the current SSH client. If there is no
// client, or the session cannot be opened, it logs the failure, reconnects
// once and retries on the fresh client. A failed reconnect is returned as-is.
func (c *comm) NewSession() (session *ssh.Session, err error) {
	if c.client != nil {
		if session, err = c.client.NewSession(); err == nil {
			return session, nil
		}
	} else {
		err = errors.New("client not available")
	}

	c.config.Logger.Error("ssh session open error, attempting reconnect", "error", err)
	if reconnectErr := c.reconnect(); reconnectErr != nil {
		c.config.Logger.Error("reconnect attempt failed", "error", reconnectErr)
		return nil, reconnectErr
	}
	return c.client.NewSession()
}
110
111func (c *comm) reconnect() error {
112	// Close previous connection.
113	if c.conn != nil {
114		c.Close()
115	}
116
117	var err error
118	c.conn, err = c.config.Connection()
119	if err != nil {
120		// Explicitly set this to the REAL nil. Connection() can return
121		// a nil implementation of net.Conn which will make the
122		// "if c.conn == nil" check fail above. Read here for more information
123		// on this psychotic language feature:
124		//
125		// http://golang.org/doc/faq#nil_error
126		c.conn = nil
127		c.config.Logger.Error("reconnection error", "error", err)
128		return err
129	}
130
131	sshConn, sshChan, req, err := ssh.NewClientConn(c.conn, c.address, c.config.SSHConfig)
132	if err != nil {
133		c.config.Logger.Error("handshake error", "error", err)
134		c.Close()
135		return err
136	}
137	if sshConn != nil {
138		c.client = ssh.NewClient(sshConn, sshChan, req)
139	}
140	c.connectToAgent()
141
142	return nil
143}
144
// connectToAgent sets up SSH agent forwarding for the current client using
// the local agent listening on $SSH_AUTH_SOCK. It is best-effort. It does
// nothing when forwarding is disabled or no agent socket is configured, and
// other failures are only logged, so the caller's connection is unaffected.
//
// The session used to request forwarding is opened directly on the client,
// not via c.NewSession. This function runs from inside reconnect, and
// c.NewSession may itself call reconnect, which could recurse without bound
// if sessions kept failing.
//
// No agent-backed auth method is appended to config.SSHConfig.Auth. This
// runs after the handshake has already authenticated, and the agent
// connection is closed once this SSH connection ends, so such a callback
// could never be used and would only pile up on every reconnect.
func (c *comm) connectToAgent() {
	if c.client == nil || c.config.DisableAgent {
		return
	}
	client := c.client

	// open connection to the local agent
	socketLocation := os.Getenv("SSH_AUTH_SOCK")
	if socketLocation == "" {
		return
	}
	agentConn, err := net.Dial("unix", socketLocation)
	if err != nil {
		c.config.Logger.Error("could not connect to local agent socket", "socket_path", socketLocation)
		return
	}

	forwardingAgent := agent.NewClient(agentConn)
	if forwardingAgent == nil {
		c.config.Logger.Error("could not create agent client")
		agentConn.Close()
		return
	}

	// Serve agent requests that the remote host opens over this client.
	if err := agent.ForwardToAgent(client, forwardingAgent); err != nil {
		c.config.Logger.Error("error registering agent forwarding handler", "error", err)
		agentConn.Close()
		return
	}

	// The forwarding handler keeps using agentConn for as long as the SSH
	// connection is up, so the socket must outlive this function. Close it
	// once the connection shuts down (Close/reconnect close the transport),
	// rather than when this function returns.
	go func() {
		_ = client.Wait()
		agentConn.Close()
	}()

	// Setup a session to request agent forwarding
	session, err := client.NewSession()
	if err != nil {
		c.config.Logger.Error("could not open session to request agent forwarding", "error", err)
		return
	}
	defer session.Close()

	if err := agent.RequestAgentForwarding(session); err != nil {
		c.config.Logger.Error("error requesting agent forwarding", "error", err)
	}
}
193
194func (c *comm) scpSession(scpCommand string, f func(io.Writer, *bufio.Reader) error) error {
195	session, err := c.NewSession()
196	if err != nil {
197		return err
198	}
199	defer session.Close()
200
201	// Get a pipe to stdin so that we can send data down
202	stdinW, err := session.StdinPipe()
203	if err != nil {
204		return err
205	}
206
207	// We only want to close once, so we nil w after we close it,
208	// and only close in the defer if it hasn't been closed already.
209	defer func() {
210		if stdinW != nil {
211			stdinW.Close()
212		}
213	}()
214
215	// Get a pipe to stdout so that we can get responses back
216	stdoutPipe, err := session.StdoutPipe()
217	if err != nil {
218		return err
219	}
220	stdoutR := bufio.NewReader(stdoutPipe)
221
222	// Set stderr to a bytes buffer
223	stderr := new(bytes.Buffer)
224	session.Stderr = stderr
225
226	// Start the sink mode on the other side
227	if err := session.Start(scpCommand); err != nil {
228		return err
229	}
230
231	// Call our callback that executes in the context of SCP. We ignore
232	// EOF errors if they occur because it usually means that SCP prematurely
233	// ended on the other side.
234	if err := f(stdinW, stdoutR); err != nil && err != io.EOF {
235		return err
236	}
237
238	// Close the stdin, which sends an EOF, and then set w to nil so that
239	// our defer func doesn't close it again since that is unsafe with
240	// the Go SSH package.
241	stdinW.Close()
242	stdinW = nil
243
244	// Wait for the SCP connection to close, meaning it has consumed all
245	// our data and has completed. Or has errored.
246	err = session.Wait()
247	if err != nil {
248		if exitErr, ok := err.(*ssh.ExitError); ok {
249			// Otherwise, we have an ExitErorr, meaning we can just read
250			// the exit status
251			c.config.Logger.Error("got non-zero exit status", "exit_status", exitErr.ExitStatus())
252
253			// If we exited with status 127, it means SCP isn't available.
254			// Return a more descriptive error for that.
255			if exitErr.ExitStatus() == 127 {
256				return errors.New(
257					"SCP failed to start. This usually means that SCP is not\n" +
258						"properly installed on the remote system.")
259			}
260		}
261
262		return err
263	}
264	return nil
265}
266
// checkSCPStatus consumes one SCP status byte from r. A zero byte means the
// previous command succeeded and nil is returned. Any other value (in
// practice 1 or 2) is treated as fatal: the message line that follows it is
// read, with its line ending stripped, and returned as the error. A failure
// to read the status byte is returned unchanged. A failure to read the
// message is wrapped.
func checkSCPStatus(r *bufio.Reader) error {
	status, err := r.ReadByte()
	switch {
	case err != nil:
		return err
	case status == 0:
		return nil
	}

	line, _, readErr := r.ReadLine()
	if readErr != nil {
		return fmt.Errorf("error reading error message: %w", readErr)
	}
	return errors.New(string(line))
}
288
// scpUploadFile sends the contents of src to the remote SCP sink as a single
// file named dst. It writes protocol messages to w (the remote command's
// stdin) and reads acknowledgements from r (its stdout).
//
// SCP is length-prefixed, so the size must be known before any data is sent.
// If fi describes a regular file, its permission bits and size are used, and
// src must yield at least fi.Size() bytes. Otherwise, including when fi
// points to a nil FileInfo, src is first spooled to a temporary file to
// measure it, and the file is sent with mode 0644.
//
// Write failures on w and non-zero statuses from the remote are returned as
// errors.
func scpUploadFile(dst string, src io.Reader, w io.Writer, r *bufio.Reader, fi *os.FileInfo) error {
	var mode os.FileMode
	var size int64

	if fi != nil && *fi != nil && (*fi).Mode().IsRegular() {
		mode = (*fi).Mode().Perm()
		size = (*fi).Size()
	} else {
		// Create a temporary file where we can copy the contents of the src
		// so that we can determine the length, since SCP is length-prefixed.
		tf, err := ioutil.TempFile("", "vault-ssh-upload")
		if err != nil {
			return fmt.Errorf("error creating temporary file for upload: %w", err)
		}
		// Deferred calls run last-in-first-out, so the file is closed
		// before it is removed.
		defer os.Remove(tf.Name())
		defer tf.Close()

		mode = 0o644

		if _, err := io.Copy(tf, src); err != nil {
			return err
		}

		// Sync the file so that the contents are definitely on disk, then
		// read the length of it.
		if err := tf.Sync(); err != nil {
			return fmt.Errorf("error creating temporary file for upload: %w", err)
		}

		// Rewind so the whole file can be re-read for sending.
		if _, err := tf.Seek(0, io.SeekStart); err != nil {
			return fmt.Errorf("error creating temporary file for upload: %w", err)
		}

		tfi, err := tf.Stat()
		if err != nil {
			return fmt.Errorf("error creating temporary file for upload: %w", err)
		}

		size = tfi.Size()
		src = tf
	}

	// File header "C<mode> <size> <name>\n", acknowledged by a status byte.
	perms := fmt.Sprintf("C%04o", mode)
	if _, err := fmt.Fprintln(w, perms, size, dst); err != nil {
		return fmt.Errorf("error sending scp file header: %w", err)
	}
	if err := checkSCPStatus(r); err != nil {
		return err
	}

	if _, err := io.CopyN(w, src, size); err != nil {
		return err
	}

	// A single NUL byte terminates the file data; the remote acknowledges
	// it with another status byte.
	if _, err := io.WriteString(w, "\x00"); err != nil {
		return fmt.Errorf("error sending scp end-of-file marker: %w", err)
	}
	return checkSCPStatus(r)
}
351