1package ssh
2
3import (
4	"fmt"
5
6	bosherr "github.com/cloudfoundry/bosh-utils/errors"
7	boshsys "github.com/cloudfoundry/bosh-utils/system"
8
9	boshdir "github.com/cloudfoundry/bosh-cli/director"
10)
11
// SessionImpl holds the inputs of a single SSH session together with the
// temporary files that back it. Start creates the private key and known hosts
// files; Finish removes them.
type SessionImpl struct {
	// connOpts carries the connection options; its PrivateKey is written to
	// privKeyFile by Start.
	connOpts ConnectionOpts
	// sessOpts carries the session options (ForceTTY) forwarded to NewSSHArgs.
	sessOpts SessionImplOpts
	// result is the SSH result whose Hosts are used to build the known hosts file.
	result   boshdir.SSHResult

	// privKeyFile and knownHostsFile are the temp files created by Start.
	// They are unset until Start has been called.
	privKeyFile    boshsys.File
	knownHostsFile boshsys.File

	// fs is the file system used to create and remove the temp files.
	fs boshsys.FileSystem
}
22
// SessionImplOpts are the per-session options accepted by NewSessionImpl.
type SessionImplOpts struct {
	// ForceTTY is passed through to NewSSHArgs by Start.
	// NOTE(review): presumably requests TTY allocation for the ssh command —
	// confirm against NewSSHArgs.
	ForceTTY bool
}
26
27func NewSessionImpl(
28	connOpts ConnectionOpts,
29	sessOpts SessionImplOpts,
30	result boshdir.SSHResult,
31	fs boshsys.FileSystem,
32) *SessionImpl {
33	return &SessionImpl{connOpts: connOpts, sessOpts: sessOpts, result: result, fs: fs}
34}
35
36func (r *SessionImpl) Start() (SSHArgs, error) {
37	var err error
38
39	r.privKeyFile, err = r.makePrivKeyFile()
40	if err != nil {
41		return SSHArgs{}, err
42	}
43
44	r.knownHostsFile, err = r.makeKnownHostsFile()
45	if err != nil {
46		_ = r.fs.RemoveAll(r.privKeyFile.Name())
47		return SSHArgs{}, err
48	}
49
50	args := NewSSHArgs(
51		r.connOpts,
52		r.result,
53		r.sessOpts.ForceTTY,
54		r.privKeyFile,
55		r.knownHostsFile,
56	)
57
58	return args, nil
59}
60
61func (r *SessionImpl) Finish() error {
62	// Make sure to try to delete all files regardless of errors
63	privKeyErr := r.fs.RemoveAll(r.privKeyFile.Name())
64	knownHostsErr := r.fs.RemoveAll(r.knownHostsFile.Name())
65
66	if privKeyErr != nil {
67		return privKeyErr
68	}
69
70	if knownHostsErr != nil {
71		return knownHostsErr
72	}
73
74	return nil
75}
76
77func (r SessionImpl) makePrivKeyFile() (boshsys.File, error) {
78	file, err := r.fs.TempFile("ssh-priv-key")
79	if err != nil {
80		return nil, bosherr.WrapErrorf(err, "Creating temp file for SSH private key")
81	}
82
83	_, err = file.Write([]byte(r.connOpts.PrivateKey))
84	if err != nil {
85		_ = r.fs.RemoveAll(file.Name())
86		return nil, bosherr.WrapErrorf(err, "Writing SSH private key")
87	}
88
89	return file, nil
90}
91
92func (r SessionImpl) makeKnownHostsFile() (boshsys.File, error) {
93	file, err := r.fs.TempFile("ssh-known-hosts")
94	if err != nil {
95		return nil, bosherr.WrapErrorf(err, "Creating temp file for SSH known hosts")
96	}
97
98	var content string
99
100	for _, host := range r.result.Hosts {
101		if len(host.HostPublicKey) > 0 {
102			content += fmt.Sprintf("%s %s\n", host.Host, host.HostPublicKey)
103		}
104	}
105
106	if len(content) > 0 {
107		_, err := file.Write([]byte(content))
108		if err != nil {
109			_ = r.fs.RemoveAll(file.Name())
110			return nil, bosherr.WrapErrorf(err, "Writing SSH known hosts")
111		}
112	}
113
114	return file, nil
115}
116