1// +build darwin 2// +build amd64 3 4package sshvault 5 6import ( 7 "fmt" 8 "io/ioutil" 9 "os" 10 "path/filepath" 11 "strings" 12 "syscall" 13 "testing" 14 15 "github.com/kr/pty" 16 "github.com/ssh-vault/go-keychain" 17) 18 19func InjectKeychainPassword(path, pw string) error { 20 item := keychain.NewItem() 21 item.SetSecClass(keychain.SecClassGenericPassword) 22 item.SetLabel(fmt.Sprintf("SSH: %s", path)) 23 item.SetService("SSH") 24 item.SetAccount(path) 25 item.SetData([]byte(pw)) 26 item.SetSynchronizable(keychain.SynchronizableNo) 27 28 return keychain.AddItem(item) 29} 30 31func DeleteKeychainPassword(path string) error { 32 item := keychain.NewItem() 33 item.SetSecClass(keychain.SecClassGenericPassword) 34 item.SetService("SSH") 35 item.SetAccount(path) 36 37 return keychain.DeleteItem(item) 38} 39 40func TestKeychain(t *testing.T) { 41 keyPw := "argle-bargle" 42 keyBadPw := "totally-bogus\n" 43 44 dir, err := ioutil.TempDir("", "vault") 45 if err != nil { 46 t.Error(err) 47 } 48 defer os.RemoveAll(dir) // clean up 49 50 tmpfile := filepath.Join(dir, "vault") 51 52 vault, err := New("", "test_data/id_rsa.pub", "", "create", tmpfile) 53 if err != nil { 54 t.Error(err) 55 } 56 keyPath, err := filepath.Abs(vault.key) 57 if err != nil { 58 t.Errorf("Error finding private key: %s", err) 59 } 60 err = InjectKeychainPassword(keyPath, keyPw) 61 if err != nil { 62 t.Errorf("Error setting up keychain for testing: %s", err) 63 } 64 defer DeleteKeychainPassword(keyPath) // clean up 65 66 pty, tty, err := pty.Open() 67 if err != nil { 68 t.Errorf("Unable to open pty: %s", err) 69 } 70 71 // File Descriptor magic. GetPasswordPrompt() reads the password 72 // from stdin. For the test, we save stdin to a spare FD, 73 // point stdin at the file, run the system under test, and 74 // finally restore the original stdin 75 oldStdin, _ := syscall.Dup(int(syscall.Stdin)) 76 oldStdout, _ := syscall.Dup(int(syscall.Stdout)) 77 syscall.Dup2(int(tty.Fd()), int(syscall.Stdin)) 78 syscall.Dup2(int(tty.Fd()), int(syscall.Stdout)) 79 80 go PtyWriteback(pty, keyBadPw) 81 82 keyPwTest, err := vault.GetPassword() 83 84 syscall.Dup2(oldStdin, int(syscall.Stdin)) 85 syscall.Dup2(oldStdout, int(syscall.Stdout)) 86 87 if err != nil { 88 t.Error(err) 89 } 90 if strings.Trim(string(keyPwTest), "\n") == strings.Trim(keyBadPw, "\n") { 91 t.Errorf("PTY-based password prompt used, not keychain!") 92 } 93 94 if strings.Trim(string(keyPwTest), "\n") != strings.Trim(keyPw, "\n") { 95 t.Errorf("keychain error: %s expected %s, got %s\n", keyPath, keyPw, keyPwTest) 96 } 97 98} 99