1package ssh 2 3import ( 4 "bufio" 5 "bytes" 6 "errors" 7 "fmt" 8 "io" 9 "io/ioutil" 10 "net" 11 "os" 12 "path/filepath" 13 14 log "github.com/hashicorp/go-hclog" 15 16 "golang.org/x/crypto/ssh" 17 "golang.org/x/crypto/ssh/agent" 18) 19 20type comm struct { 21 client *ssh.Client 22 config *SSHCommConfig 23 conn net.Conn 24 address string 25} 26 27// SSHCommConfig is the structure used to configure the SSH communicator. 28type SSHCommConfig struct { 29 // The configuration of the Go SSH connection 30 SSHConfig *ssh.ClientConfig 31 32 // Connection returns a new connection. The current connection 33 // in use will be closed as part of the Close method, or in the 34 // case an error occurs. 35 Connection func() (net.Conn, error) 36 37 // Pty, if true, will request a pty from the remote end. 38 Pty bool 39 40 // DisableAgent, if true, will not forward the SSH agent. 41 DisableAgent bool 42 43 // Logger for output 44 Logger log.Logger 45} 46 47// Creates a new communicator implementation over SSH. This takes 48// an already existing TCP connection and SSH configuration. 49func SSHCommNew(address string, config *SSHCommConfig) (result *comm, err error) { 50 // Establish an initial connection and connect 51 result = &comm{ 52 config: config, 53 address: address, 54 } 55 56 if err = result.reconnect(); err != nil { 57 result = nil 58 return 59 } 60 61 return 62} 63 64func (c *comm) Close() error { 65 var err error 66 if c.conn != nil { 67 err = c.conn.Close() 68 } 69 c.conn = nil 70 c.client = nil 71 return err 72} 73 74func (c *comm) Upload(path string, input io.Reader, fi *os.FileInfo) error { 75 // The target directory and file for talking the SCP protocol 76 target_dir := filepath.Dir(path) 77 target_file := filepath.Base(path) 78 79 // On windows, filepath.Dir uses backslash separators (ie. "\tmp"). 80 // This does not work when the target host is unix. Switch to forward slash 81 // which works for unix and windows 82 target_dir = filepath.ToSlash(target_dir) 83 84 scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error { 85 return scpUploadFile(target_file, input, w, stdoutR, fi) 86 } 87 88 return c.scpSession("scp -vt "+target_dir, scpFunc) 89} 90 91func (c *comm) NewSession() (session *ssh.Session, err error) { 92 if c.client == nil { 93 err = errors.New("client not available") 94 } else { 95 session, err = c.client.NewSession() 96 } 97 98 if err != nil { 99 c.config.Logger.Error("ssh session open error, attempting reconnect", "error", err) 100 if err := c.reconnect(); err != nil { 101 c.config.Logger.Error("reconnect attempt failed", "error", err) 102 return nil, err 103 } 104 105 return c.client.NewSession() 106 } 107 108 return session, nil 109} 110 111func (c *comm) reconnect() error { 112 // Close previous connection. 113 if c.conn != nil { 114 c.Close() 115 } 116 117 var err error 118 c.conn, err = c.config.Connection() 119 if err != nil { 120 // Explicitly set this to the REAL nil. Connection() can return 121 // a nil implementation of net.Conn which will make the 122 // "if c.conn == nil" check fail above. Read here for more information 123 // on this psychotic language feature: 124 // 125 // http://golang.org/doc/faq#nil_error 126 c.conn = nil 127 c.config.Logger.Error("reconnection error", "error", err) 128 return err 129 } 130 131 sshConn, sshChan, req, err := ssh.NewClientConn(c.conn, c.address, c.config.SSHConfig) 132 if err != nil { 133 c.config.Logger.Error("handshake error", "error", err) 134 c.Close() 135 return err 136 } 137 if sshConn != nil { 138 c.client = ssh.NewClient(sshConn, sshChan, req) 139 } 140 c.connectToAgent() 141 142 return nil 143} 144 145func (c *comm) connectToAgent() { 146 if c.client == nil { 147 return 148 } 149 150 if c.config.DisableAgent { 151 return 152 } 153 154 // open connection to the local agent 155 socketLocation := os.Getenv("SSH_AUTH_SOCK") 156 if socketLocation == "" { 157 return 158 } 159 agentConn, err := net.Dial("unix", socketLocation) 160 if err != nil { 161 c.config.Logger.Error("could not connect to local agent socket", "socket_path", socketLocation) 162 return 163 } 164 defer agentConn.Close() 165 166 // create agent and add in auth 167 forwardingAgent := agent.NewClient(agentConn) 168 if forwardingAgent == nil { 169 c.config.Logger.Error("could not create agent client") 170 return 171 } 172 173 // add callback for forwarding agent to SSH config 174 // XXX - might want to handle reconnects appending multiple callbacks 175 auth := ssh.PublicKeysCallback(forwardingAgent.Signers) 176 c.config.SSHConfig.Auth = append(c.config.SSHConfig.Auth, auth) 177 agent.ForwardToAgent(c.client, forwardingAgent) 178 179 // Setup a session to request agent forwarding 180 session, err := c.NewSession() 181 if err != nil { 182 return 183 } 184 defer session.Close() 185 186 err = agent.RequestAgentForwarding(session) 187 if err != nil { 188 c.config.Logger.Error("error requesting agent forwarding", "error", err) 189 return 190 } 191 return 192} 193 194func (c *comm) scpSession(scpCommand string, f func(io.Writer, *bufio.Reader) error) error { 195 session, err := c.NewSession() 196 if err != nil { 197 return err 198 } 199 defer session.Close() 200 201 // Get a pipe to stdin so that we can send data down 202 stdinW, err := session.StdinPipe() 203 if err != nil { 204 return err 205 } 206 207 // We only want to close once, so we nil w after we close it, 208 // and only close in the defer if it hasn't been closed already. 209 defer func() { 210 if stdinW != nil { 211 stdinW.Close() 212 } 213 }() 214 215 // Get a pipe to stdout so that we can get responses back 216 stdoutPipe, err := session.StdoutPipe() 217 if err != nil { 218 return err 219 } 220 stdoutR := bufio.NewReader(stdoutPipe) 221 222 // Set stderr to a bytes buffer 223 stderr := new(bytes.Buffer) 224 session.Stderr = stderr 225 226 // Start the sink mode on the other side 227 if err := session.Start(scpCommand); err != nil { 228 return err 229 } 230 231 // Call our callback that executes in the context of SCP. We ignore 232 // EOF errors if they occur because it usually means that SCP prematurely 233 // ended on the other side. 234 if err := f(stdinW, stdoutR); err != nil && err != io.EOF { 235 return err 236 } 237 238 // Close the stdin, which sends an EOF, and then set w to nil so that 239 // our defer func doesn't close it again since that is unsafe with 240 // the Go SSH package. 241 stdinW.Close() 242 stdinW = nil 243 244 // Wait for the SCP connection to close, meaning it has consumed all 245 // our data and has completed. Or has errored. 246 err = session.Wait() 247 if err != nil { 248 if exitErr, ok := err.(*ssh.ExitError); ok { 249 // Otherwise, we have an ExitErorr, meaning we can just read 250 // the exit status 251 c.config.Logger.Error("got non-zero exit status", "exit_status", exitErr.ExitStatus()) 252 253 // If we exited with status 127, it means SCP isn't available. 254 // Return a more descriptive error for that. 255 if exitErr.ExitStatus() == 127 { 256 return errors.New( 257 "SCP failed to start. This usually means that SCP is not\n" + 258 "properly installed on the remote system.") 259 } 260 } 261 262 return err 263 } 264 return nil 265} 266 267// checkSCPStatus checks that a prior command sent to SCP completed 268// successfully. If it did not complete successfully, an error will 269// be returned. 270func checkSCPStatus(r *bufio.Reader) error { 271 code, err := r.ReadByte() 272 if err != nil { 273 return err 274 } 275 276 if code != 0 { 277 // Treat any non-zero (really 1 and 2) as fatal errors 278 message, _, err := r.ReadLine() 279 if err != nil { 280 return fmt.Errorf("error reading error message: %w", err) 281 } 282 283 return errors.New(string(message)) 284 } 285 286 return nil 287} 288 289func scpUploadFile(dst string, src io.Reader, w io.Writer, r *bufio.Reader, fi *os.FileInfo) error { 290 var mode os.FileMode 291 var size int64 292 293 if fi != nil && (*fi).Mode().IsRegular() { 294 mode = (*fi).Mode().Perm() 295 size = (*fi).Size() 296 } else { 297 // Create a temporary file where we can copy the contents of the src 298 // so that we can determine the length, since SCP is length-prefixed. 299 tf, err := ioutil.TempFile("", "vault-ssh-upload") 300 if err != nil { 301 return fmt.Errorf("error creating temporary file for upload: %w", err) 302 } 303 defer os.Remove(tf.Name()) 304 defer tf.Close() 305 306 mode = 0o644 307 308 if _, err := io.Copy(tf, src); err != nil { 309 return err 310 } 311 312 // Sync the file so that the contents are definitely on disk, then 313 // read the length of it. 314 if err := tf.Sync(); err != nil { 315 return fmt.Errorf("error creating temporary file for upload: %w", err) 316 } 317 318 // Seek the file to the beginning so we can re-read all of it 319 if _, err := tf.Seek(0, 0); err != nil { 320 return fmt.Errorf("error creating temporary file for upload: %w", err) 321 } 322 323 tfi, err := tf.Stat() 324 if err != nil { 325 return fmt.Errorf("error creating temporary file for upload: %w", err) 326 } 327 328 size = tfi.Size() 329 src = tf 330 } 331 332 // Start the protocol 333 perms := fmt.Sprintf("C%04o", mode) 334 335 fmt.Fprintln(w, perms, size, dst) 336 if err := checkSCPStatus(r); err != nil { 337 return err 338 } 339 340 if _, err := io.CopyN(w, src, size); err != nil { 341 return err 342 } 343 344 fmt.Fprint(w, "\x00") 345 if err := checkSCPStatus(r); err != nil { 346 return err 347 } 348 349 return nil 350} 351