1// +build darwin
2// +build amd64
3
4package sshvault
5
6import (
7	"fmt"
8	"io/ioutil"
9	"os"
10	"path/filepath"
11	"strings"
12	"syscall"
13	"testing"
14
15	"github.com/kr/pty"
16	"github.com/ssh-vault/go-keychain"
17)
18
19func InjectKeychainPassword(path, pw string) error {
20	item := keychain.NewItem()
21	item.SetSecClass(keychain.SecClassGenericPassword)
22	item.SetLabel(fmt.Sprintf("SSH: %s", path))
23	item.SetService("SSH")
24	item.SetAccount(path)
25	item.SetData([]byte(pw))
26	item.SetSynchronizable(keychain.SynchronizableNo)
27
28	return keychain.AddItem(item)
29}
30
31func DeleteKeychainPassword(path string) error {
32	item := keychain.NewItem()
33	item.SetSecClass(keychain.SecClassGenericPassword)
34	item.SetService("SSH")
35	item.SetAccount(path)
36
37	return keychain.DeleteItem(item)
38}
39
40func TestKeychain(t *testing.T) {
41	keyPw := "argle-bargle"
42	keyBadPw := "totally-bogus\n"
43
44	dir, err := ioutil.TempDir("", "vault")
45	if err != nil {
46		t.Error(err)
47	}
48	defer os.RemoveAll(dir) // clean up
49
50	tmpfile := filepath.Join(dir, "vault")
51
52	vault, err := New("", "test_data/id_rsa.pub", "", "create", tmpfile)
53	if err != nil {
54		t.Error(err)
55	}
56	keyPath, err := filepath.Abs(vault.key)
57	if err != nil {
58		t.Errorf("Error finding private key: %s", err)
59	}
60	err = InjectKeychainPassword(keyPath, keyPw)
61	if err != nil {
62		t.Errorf("Error setting up keychain for testing: %s", err)
63	}
64	defer DeleteKeychainPassword(keyPath) // clean up
65
66	pty, tty, err := pty.Open()
67	if err != nil {
68		t.Errorf("Unable to open pty: %s", err)
69	}
70
71	// File Descriptor magic. GetPasswordPrompt() reads the password
72	// from stdin. For the test, we save stdin to a spare FD,
73	// point stdin at the file, run the system under test, and
74	// finally restore the original stdin
75	oldStdin, _ := syscall.Dup(int(syscall.Stdin))
76	oldStdout, _ := syscall.Dup(int(syscall.Stdout))
77	syscall.Dup2(int(tty.Fd()), int(syscall.Stdin))
78	syscall.Dup2(int(tty.Fd()), int(syscall.Stdout))
79
80	go PtyWriteback(pty, keyBadPw)
81
82	keyPwTest, err := vault.GetPassword()
83
84	syscall.Dup2(oldStdin, int(syscall.Stdin))
85	syscall.Dup2(oldStdout, int(syscall.Stdout))
86
87	if err != nil {
88		t.Error(err)
89	}
90	if strings.Trim(string(keyPwTest), "\n") == strings.Trim(keyBadPw, "\n") {
91		t.Errorf("PTY-based password prompt used, not keychain!")
92	}
93
94	if strings.Trim(string(keyPwTest), "\n") != strings.Trim(keyPw, "\n") {
95		t.Errorf("keychain error: %s expected %s, got %s\n", keyPath, keyPw, keyPwTest)
96	}
97
98}
99