1package git
2
3import (
4	"context"
5	"fmt"
6	"os"
7	"path/filepath"
8	"strings"
9
10	"github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus"
11)
12
13// BuildSSHInvocation builds a command line to invoke SSH with the provided key and known hosts.
14// Both are optional.
15func BuildSSHInvocation(ctx context.Context, sshKey, knownHosts string) (string, func(), error) {
16	const sshCommand = "ssh"
17	if sshKey == "" && knownHosts == "" {
18		return sshCommand, func() {}, nil
19	}
20
21	tmpDir, err := os.MkdirTemp("", "gitaly-ssh-invocation")
22	if err != nil {
23		return "", func() {}, fmt.Errorf("create temporary directory: %w", err)
24	}
25
26	cleanup := func() {
27		if err := os.RemoveAll(tmpDir); err != nil {
28			ctxlogrus.Extract(ctx).WithError(err).Error("failed to remove tmp directory with ssh key/config")
29		}
30	}
31
32	args := []string{sshCommand}
33	if sshKey != "" {
34		sshKeyFile := filepath.Join(tmpDir, "ssh-key")
35		if err := os.WriteFile(sshKeyFile, []byte(sshKey), 0o400); err != nil {
36			cleanup()
37			return "", nil, fmt.Errorf("create ssh key file: %w", err)
38		}
39
40		args = append(args, "-oIdentitiesOnly=yes", "-oIdentityFile="+sshKeyFile)
41	}
42
43	if knownHosts != "" {
44		knownHostsFile := filepath.Join(tmpDir, "known-hosts")
45		if err := os.WriteFile(knownHostsFile, []byte(knownHosts), 0o400); err != nil {
46			cleanup()
47			return "", nil, fmt.Errorf("create known hosts file: %w", err)
48		}
49
50		args = append(args, "-oStrictHostKeyChecking=yes", "-oCheckHostIP=no", "-oUserKnownHostsFile="+knownHostsFile)
51	}
52
53	return strings.Join(args, " "), cleanup, nil
54}
55