1// Copyright 2017 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
// Tests for SSH client multi-auth, i.e. servers that require several
// authentication methods to succeed in sequence.
//
// These tests run a simple Go SSH client against an OpenSSH server
// over unix domain sockets. The tests use multiple combinations
// of password, keyboard-interactive and publickey authentication
// methods.
//
// A wrapper library that makes sshd's PAM authentication accept test
// passwords is required at ./sshd_test_pw.so. If the library does
// not exist, these tests are skipped. See the compile instructions
// (for Linux) in ./sshd_test_pw.c.
16
17// +build linux
18
19package test
20
21import (
22	"fmt"
23	"strings"
24	"testing"
25
26	"golang.org/x/crypto/ssh"
27)
28
// multiAuthTestCase describes one TestMultiAuth scenario: the ordered list
// of authentication methods the server requires, plus how many times each
// client-side callback is expected to fire during a successful login.
type multiAuthTestCase struct {
	// authMethods is the ordered list of method names; joined with "," it
	// becomes both the server's AuthMethods setting and the subtest name.
	authMethods []string

	// Expected invocation counts for the password and keyboard-interactive
	// callbacks, respectively.
	expectedPasswordCbs, expectedKbdIntCbs int
}
35
// multiAuthTestCtx is the per-subtest state shared with the client
// authentication callbacks: the password they answer with, and running
// counters of how often each callback has been invoked.
type multiAuthTestCtx struct {
	password string // secret handed back by passwordCb and kbdIntCb

	// Invocation counters for passwordCb and kbdIntCb.
	numPasswordCbs, numKbdIntCbs int
}
42
43// create test context
44func newMultiAuthTestCtx(t *testing.T) *multiAuthTestCtx {
45	password, err := randomPassword()
46	if err != nil {
47		t.Fatalf("Failed to generate random test password: %s", err.Error())
48	}
49
50	return &multiAuthTestCtx{
51		password: password,
52	}
53}
54
55// password callback
56func (ctx *multiAuthTestCtx) passwordCb() (secret string, err error) {
57	ctx.numPasswordCbs++
58	return ctx.password, nil
59}
60
61// keyboard-interactive callback
62func (ctx *multiAuthTestCtx) kbdIntCb(user, instruction string, questions []string, echos []bool) (answers []string, err error) {
63	if len(questions) == 0 {
64		return nil, nil
65	}
66
67	ctx.numKbdIntCbs++
68	if len(questions) == 1 {
69		return []string{ctx.password}, nil
70	}
71
72	return nil, fmt.Errorf("unsupported keyboard-interactive flow")
73}
74
75// TestMultiAuth runs several subtests for different combinations of password, keyboard-interactive and publickey authentication methods
76func TestMultiAuth(t *testing.T) {
77	testCases := []multiAuthTestCase{
78		// Test password,publickey authentication, assert that password callback is called 1 time
79		multiAuthTestCase{
80			authMethods:         []string{"password", "publickey"},
81			expectedPasswordCbs: 1,
82		},
83		// Test keyboard-interactive,publickey authentication, assert that keyboard-interactive callback is called 1 time
84		multiAuthTestCase{
85			authMethods:       []string{"keyboard-interactive", "publickey"},
86			expectedKbdIntCbs: 1,
87		},
88		// Test publickey,password authentication, assert that password callback is called 1 time
89		multiAuthTestCase{
90			authMethods:         []string{"publickey", "password"},
91			expectedPasswordCbs: 1,
92		},
93		// Test publickey,keyboard-interactive authentication, assert that keyboard-interactive callback is called 1 time
94		multiAuthTestCase{
95			authMethods:       []string{"publickey", "keyboard-interactive"},
96			expectedKbdIntCbs: 1,
97		},
98		// Test password,password authentication, assert that password callback is called 2 times
99		multiAuthTestCase{
100			authMethods:         []string{"password", "password"},
101			expectedPasswordCbs: 2,
102		},
103	}
104
105	for _, testCase := range testCases {
106		t.Run(strings.Join(testCase.authMethods, ","), func(t *testing.T) {
107			ctx := newMultiAuthTestCtx(t)
108
109			server := newServerForConfig(t, "MultiAuth", map[string]string{"AuthMethods": strings.Join(testCase.authMethods, ",")})
110			defer server.Shutdown()
111
112			clientConfig := clientConfig()
113			server.setTestPassword(clientConfig.User, ctx.password)
114
115			publicKeyAuthMethod := clientConfig.Auth[0]
116			clientConfig.Auth = nil
117			for _, authMethod := range testCase.authMethods {
118				switch authMethod {
119				case "publickey":
120					clientConfig.Auth = append(clientConfig.Auth, publicKeyAuthMethod)
121				case "password":
122					clientConfig.Auth = append(clientConfig.Auth,
123						ssh.RetryableAuthMethod(ssh.PasswordCallback(ctx.passwordCb), 5))
124				case "keyboard-interactive":
125					clientConfig.Auth = append(clientConfig.Auth,
126						ssh.RetryableAuthMethod(ssh.KeyboardInteractive(ctx.kbdIntCb), 5))
127				default:
128					t.Fatalf("Unknown authentication method %s", authMethod)
129				}
130			}
131
132			conn := server.Dial(clientConfig)
133			defer conn.Close()
134
135			if ctx.numPasswordCbs != testCase.expectedPasswordCbs {
136				t.Fatalf("passwordCallback was called %d times, expected %d times", ctx.numPasswordCbs, testCase.expectedPasswordCbs)
137			}
138
139			if ctx.numKbdIntCbs != testCase.expectedKbdIntCbs {
140				t.Fatalf("keyboardInteractiveCallback was called %d times, expected %d times", ctx.numKbdIntCbs, testCase.expectedKbdIntCbs)
141			}
142		})
143	}
144}
145