1// Copyright 2011 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package ssh
6
7import (
8	"bytes"
9	"errors"
10	"fmt"
11	"io"
12	"net"
13	"strings"
14)
15
// Permissions carries fine-grained permissions that belong to a user,
// or to one particular authentication method of a user. The value
// produced by a successful authentication attempt is exposed on
// ServerConn, which makes it a way to hand information from the
// user-authentication phase over to the application layer.
type Permissions struct {
	// CriticalOptions restrict the default permissions and are
	// typically used together with user certificates. The SSH
	// certificate standard defines "force-command" (only the given
	// command may be executed) and "source-address" (connections
	// are only allowed from the given address). This package
	// currently enforces the "source-address" option only; server
	// implementations have to enforce any other critical option,
	// such as "force-command", themselves by checking it once the
	// SSH handshake has succeeded. In general, SSH servers should
	// reject connections that specify critical options that are
	// unknown or not supported.
	CriticalOptions map[string]string

	// Extensions describe extra functionality the server may offer
	// on authenticated connections. Not supporting an extension
	// does not prevent a user from authenticating. Common
	// extensions are "permit-agent-forwarding" and
	// "permit-X11-forwarding". This package currently does not act
	// on any extension; honoring them is left to the server
	// implementation. Extensions may also be used to pass data from
	// the authentication callbacks to the server application layer.
	Extensions map[string]string
}
47
48// ServerConfig holds server specific configuration data.
49type ServerConfig struct {
50	// Config contains configuration shared between client and server.
51	Config
52
53	hostKeys []Signer
54
55	// NoClientAuth is true if clients are allowed to connect without
56	// authenticating.
57	NoClientAuth bool
58
59	// MaxAuthTries specifies the maximum number of authentication attempts
60	// permitted per connection. If set to a negative number, the number of
61	// attempts are unlimited. If set to zero, the number of attempts are limited
62	// to 6.
63	MaxAuthTries int
64
65	// PasswordCallback, if non-nil, is called when a user
66	// attempts to authenticate using a password.
67	PasswordCallback func(conn ConnMetadata, password []byte) (*Permissions, error)
68
69	// PublicKeyCallback, if non-nil, is called when a client
70	// offers a public key for authentication. It must return a nil error
71	// if the given public key can be used to authenticate the
72	// given user. For example, see CertChecker.Authenticate. A
73	// call to this function does not guarantee that the key
74	// offered is in fact used to authenticate. To record any data
75	// depending on the public key, store it inside a
76	// Permissions.Extensions entry.
77	PublicKeyCallback func(conn ConnMetadata, key PublicKey) (*Permissions, error)
78
79	// KeyboardInteractiveCallback, if non-nil, is called when
80	// keyboard-interactive authentication is selected (RFC
81	// 4256). The client object's Challenge function should be
82	// used to query the user. The callback may offer multiple
83	// Challenge rounds. To avoid information leaks, the client
84	// should be presented a challenge even if the user is
85	// unknown.
86	KeyboardInteractiveCallback func(conn ConnMetadata, client KeyboardInteractiveChallenge) (*Permissions, error)
87
88	// AuthLogCallback, if non-nil, is called to log all authentication
89	// attempts.
90	AuthLogCallback func(conn ConnMetadata, method string, err error)
91
92	// ServerVersion is the version identification string to announce in
93	// the public handshake.
94	// If empty, a reasonable default is used.
95	// Note that RFC 4253 section 4.2 requires that this string start with
96	// "SSH-2.0-".
97	ServerVersion string
98}
99
100// AddHostKey adds a private key as a host key. If an existing host
101// key exists with the same algorithm, it is overwritten. Each server
102// config must have at least one host key.
103func (s *ServerConfig) AddHostKey(key Signer) {
104	for i, k := range s.hostKeys {
105		if k.PublicKey().Type() == key.PublicKey().Type() {
106			s.hostKeys[i] = key
107			return
108		}
109	}
110
111	s.hostKeys = append(s.hostKeys, key)
112}
113
114// cachedPubKey contains the results of querying whether a public key is
115// acceptable for a user.
116type cachedPubKey struct {
117	user       string
118	pubKeyData []byte
119	result     error
120	perms      *Permissions
121}
122
123const maxCachedPubKeys = 16
124
125// pubKeyCache caches tests for public keys.  Since SSH clients
126// will query whether a public key is acceptable before attempting to
127// authenticate with it, we end up with duplicate queries for public
128// key validity.  The cache only applies to a single ServerConn.
129type pubKeyCache struct {
130	keys []cachedPubKey
131}
132
133// get returns the result for a given user/algo/key tuple.
134func (c *pubKeyCache) get(user string, pubKeyData []byte) (cachedPubKey, bool) {
135	for _, k := range c.keys {
136		if k.user == user && bytes.Equal(k.pubKeyData, pubKeyData) {
137			return k, true
138		}
139	}
140	return cachedPubKey{}, false
141}
142
143// add adds the given tuple to the cache.
144func (c *pubKeyCache) add(candidate cachedPubKey) {
145	if len(c.keys) < maxCachedPubKeys {
146		c.keys = append(c.keys, candidate)
147	}
148}
149
150// ServerConn is an authenticated SSH connection, as seen from the
151// server
152type ServerConn struct {
153	Conn
154
155	// If the succeeding authentication callback returned a
156	// non-nil Permissions pointer, it is stored here.
157	Permissions *Permissions
158}
159
160// NewServerConn starts a new SSH server with c as the underlying
161// transport.  It starts with a handshake and, if the handshake is
162// unsuccessful, it closes the connection and returns an error.  The
163// Request and NewChannel channels must be serviced, or the connection
164// will hang.
165func NewServerConn(c net.Conn, config *ServerConfig) (*ServerConn, <-chan NewChannel, <-chan *Request, error) {
166	fullConf := *config
167	fullConf.SetDefaults()
168	if fullConf.MaxAuthTries == 0 {
169		fullConf.MaxAuthTries = 6
170	}
171
172	s := &connection{
173		sshConn: sshConn{conn: c},
174	}
175	perms, err := s.serverHandshake(&fullConf)
176	if err != nil {
177		c.Close()
178		return nil, nil, nil, err
179	}
180	return &ServerConn{s, perms}, s.mux.incomingChannels, s.mux.incomingRequests, nil
181}
182
183// signAndMarshal signs the data with the appropriate algorithm,
184// and serializes the result in SSH wire format.
185func signAndMarshal(k Signer, rand io.Reader, data []byte) ([]byte, error) {
186	sig, err := k.Sign(rand, data)
187	if err != nil {
188		return nil, err
189	}
190
191	return Marshal(sig), nil
192}
193
194// handshake performs key exchange and user authentication.
195func (s *connection) serverHandshake(config *ServerConfig) (*Permissions, error) {
196	if len(config.hostKeys) == 0 {
197		return nil, errors.New("ssh: server has no host keys")
198	}
199
200	if !config.NoClientAuth && config.PasswordCallback == nil && config.PublicKeyCallback == nil && config.KeyboardInteractiveCallback == nil {
201		return nil, errors.New("ssh: no authentication methods configured but NoClientAuth is also false")
202	}
203
204	if config.ServerVersion != "" {
205		s.serverVersion = []byte(config.ServerVersion)
206	} else {
207		s.serverVersion = []byte(packageVersion)
208	}
209	var err error
210	s.clientVersion, err = exchangeVersions(s.sshConn.conn, s.serverVersion)
211	if err != nil {
212		return nil, err
213	}
214
215	tr := newTransport(s.sshConn.conn, config.Rand, false /* not client */)
216	s.transport = newServerTransport(tr, s.clientVersion, s.serverVersion, config)
217
218	if err := s.transport.waitSession(); err != nil {
219		return nil, err
220	}
221
222	// We just did the key change, so the session ID is established.
223	s.sessionID = s.transport.getSessionID()
224
225	var packet []byte
226	if packet, err = s.transport.readPacket(); err != nil {
227		return nil, err
228	}
229
230	var serviceRequest serviceRequestMsg
231	if err = Unmarshal(packet, &serviceRequest); err != nil {
232		return nil, err
233	}
234	if serviceRequest.Service != serviceUserAuth {
235		return nil, errors.New("ssh: requested service '" + serviceRequest.Service + "' before authenticating")
236	}
237	serviceAccept := serviceAcceptMsg{
238		Service: serviceUserAuth,
239	}
240	if err := s.transport.writePacket(Marshal(&serviceAccept)); err != nil {
241		return nil, err
242	}
243
244	perms, err := s.serverAuthenticate(config)
245	if err != nil {
246		return nil, err
247	}
248	s.mux = newMux(s.transport)
249	return perms, err
250}
251
252func isAcceptableAlgo(algo string) bool {
253	switch algo {
254	case KeyAlgoRSA, KeyAlgoDSA, KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521, KeyAlgoED25519,
255		CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01:
256		return true
257	}
258	return false
259}
260
// checkSourceAddress verifies that addr is permitted by sourceAddrs, a
// comma-separated list whose entries are either single IP addresses
// or CIDR networks. It returns nil as soon as an entry matches the IP
// of addr. It returns an error when addr is nil or not a
// *net.TCPAddr, when an entry that is reached is neither an IP
// address nor a valid CIDR network, or when no entry matches.
func checkSourceAddress(addr net.Addr, sourceAddrs string) error {
	if addr == nil {
		return errors.New("ssh: no address known for client, but source-address match required")
	}

	remote, isTCP := addr.(*net.TCPAddr)
	if !isTCP {
		return fmt.Errorf("ssh: remote address %v is not an TCP address when checking source-address match", addr)
	}

	for _, entry := range strings.Split(sourceAddrs, ",") {
		// An entry that parses as a plain IP address has to be equal
		// to the remote IP.
		if ip := net.ParseIP(entry); ip != nil {
			if ip.Equal(remote.IP) {
				return nil
			}
			continue
		}

		// Anything else must be a CIDR network containing the remote
		// IP.
		_, network, err := net.ParseCIDR(entry)
		if err != nil {
			return fmt.Errorf("ssh: error parsing source-address restriction %q: %v", entry, err)
		}
		if network.Contains(remote.IP) {
			return nil
		}
	}

	return fmt.Errorf("ssh: remote address %v is not allowed because of source-address restriction", addr)
}
290
// ServerAuthError implements the error interface. It collects the
// authentication errors that occurred, and is returned if all of the
// authentication methods attempted by the user failed to
// authenticate.
type ServerAuthError struct {
	// Errors contains authentication errors returned by the
	// authentication callback methods.
	Errors []error
}

// Error returns the messages of all collected errors, separated by
// ", " and enclosed in square brackets. An empty list yields "[]". A
// nil entry is rendered as "<nil>" instead of causing a panic.
func (l ServerAuthError) Error() string {
	errs := make([]string, 0, len(l.Errors))
	for _, err := range l.Errors {
		if err == nil {
			// Calling Error on a nil error would panic; print it the
			// way fmt prints a nil error instead.
			errs = append(errs, "<nil>")
			continue
		}
		errs = append(errs, err.Error())
	}
	return "[" + strings.Join(errs, ", ") + "]"
}
307
308func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, error) {
309	sessionID := s.transport.getSessionID()
310	var cache pubKeyCache
311	var perms *Permissions
312
313	authFailures := 0
314	var authErrs []error
315
316userAuthLoop:
317	for {
318		if authFailures >= config.MaxAuthTries && config.MaxAuthTries > 0 {
319			discMsg := &disconnectMsg{
320				Reason:  2,
321				Message: "too many authentication failures",
322			}
323
324			if err := s.transport.writePacket(Marshal(discMsg)); err != nil {
325				return nil, err
326			}
327
328			return nil, discMsg
329		}
330
331		var userAuthReq userAuthRequestMsg
332		if packet, err := s.transport.readPacket(); err != nil {
333			if err == io.EOF {
334				return nil, &ServerAuthError{Errors: authErrs}
335			}
336			return nil, err
337		} else if err = Unmarshal(packet, &userAuthReq); err != nil {
338			return nil, err
339		}
340
341		if userAuthReq.Service != serviceSSH {
342			return nil, errors.New("ssh: client attempted to negotiate for unknown service: " + userAuthReq.Service)
343		}
344
345		s.user = userAuthReq.User
346		perms = nil
347		authErr := errors.New("no auth passed yet")
348
349		switch userAuthReq.Method {
350		case "none":
351			if config.NoClientAuth {
352				authErr = nil
353			}
354
355			// allow initial attempt of 'none' without penalty
356			if authFailures == 0 {
357				authFailures--
358			}
359		case "password":
360			if config.PasswordCallback == nil {
361				authErr = errors.New("ssh: password auth not configured")
362				break
363			}
364			payload := userAuthReq.Payload
365			if len(payload) < 1 || payload[0] != 0 {
366				return nil, parseError(msgUserAuthRequest)
367			}
368			payload = payload[1:]
369			password, payload, ok := parseString(payload)
370			if !ok || len(payload) > 0 {
371				return nil, parseError(msgUserAuthRequest)
372			}
373
374			perms, authErr = config.PasswordCallback(s, password)
375		case "keyboard-interactive":
376			if config.KeyboardInteractiveCallback == nil {
377				authErr = errors.New("ssh: keyboard-interactive auth not configubred")
378				break
379			}
380
381			prompter := &sshClientKeyboardInteractive{s}
382			perms, authErr = config.KeyboardInteractiveCallback(s, prompter.Challenge)
383		case "publickey":
384			if config.PublicKeyCallback == nil {
385				authErr = errors.New("ssh: publickey auth not configured")
386				break
387			}
388			payload := userAuthReq.Payload
389			if len(payload) < 1 {
390				return nil, parseError(msgUserAuthRequest)
391			}
392			isQuery := payload[0] == 0
393			payload = payload[1:]
394			algoBytes, payload, ok := parseString(payload)
395			if !ok {
396				return nil, parseError(msgUserAuthRequest)
397			}
398			algo := string(algoBytes)
399			if !isAcceptableAlgo(algo) {
400				authErr = fmt.Errorf("ssh: algorithm %q not accepted", algo)
401				break
402			}
403
404			pubKeyData, payload, ok := parseString(payload)
405			if !ok {
406				return nil, parseError(msgUserAuthRequest)
407			}
408
409			pubKey, err := ParsePublicKey(pubKeyData)
410			if err != nil {
411				return nil, err
412			}
413
414			candidate, ok := cache.get(s.user, pubKeyData)
415			if !ok {
416				candidate.user = s.user
417				candidate.pubKeyData = pubKeyData
418				candidate.perms, candidate.result = config.PublicKeyCallback(s, pubKey)
419				if candidate.result == nil && candidate.perms != nil && candidate.perms.CriticalOptions != nil && candidate.perms.CriticalOptions[sourceAddressCriticalOption] != "" {
420					candidate.result = checkSourceAddress(
421						s.RemoteAddr(),
422						candidate.perms.CriticalOptions[sourceAddressCriticalOption])
423				}
424				cache.add(candidate)
425			}
426
427			if isQuery {
428				// The client can query if the given public key
429				// would be okay.
430
431				if len(payload) > 0 {
432					return nil, parseError(msgUserAuthRequest)
433				}
434
435				if candidate.result == nil {
436					okMsg := userAuthPubKeyOkMsg{
437						Algo:   algo,
438						PubKey: pubKeyData,
439					}
440					if err = s.transport.writePacket(Marshal(&okMsg)); err != nil {
441						return nil, err
442					}
443					continue userAuthLoop
444				}
445				authErr = candidate.result
446			} else {
447				sig, payload, ok := parseSignature(payload)
448				if !ok || len(payload) > 0 {
449					return nil, parseError(msgUserAuthRequest)
450				}
451				// Ensure the public key algo and signature algo
452				// are supported.  Compare the private key
453				// algorithm name that corresponds to algo with
454				// sig.Format.  This is usually the same, but
455				// for certs, the names differ.
456				if !isAcceptableAlgo(sig.Format) {
457					break
458				}
459				signedData := buildDataSignedForAuth(sessionID, userAuthReq, algoBytes, pubKeyData)
460
461				if err := pubKey.Verify(signedData, sig); err != nil {
462					return nil, err
463				}
464
465				authErr = candidate.result
466				perms = candidate.perms
467			}
468		default:
469			authErr = fmt.Errorf("ssh: unknown method %q", userAuthReq.Method)
470		}
471
472		authErrs = append(authErrs, authErr)
473
474		if config.AuthLogCallback != nil {
475			config.AuthLogCallback(s, userAuthReq.Method, authErr)
476		}
477
478		if authErr == nil {
479			break userAuthLoop
480		}
481
482		authFailures++
483
484		var failureMsg userAuthFailureMsg
485		if config.PasswordCallback != nil {
486			failureMsg.Methods = append(failureMsg.Methods, "password")
487		}
488		if config.PublicKeyCallback != nil {
489			failureMsg.Methods = append(failureMsg.Methods, "publickey")
490		}
491		if config.KeyboardInteractiveCallback != nil {
492			failureMsg.Methods = append(failureMsg.Methods, "keyboard-interactive")
493		}
494
495		if len(failureMsg.Methods) == 0 {
496			return nil, errors.New("ssh: no authentication methods configured but NoClientAuth is also false")
497		}
498
499		if err := s.transport.writePacket(Marshal(&failureMsg)); err != nil {
500			return nil, err
501		}
502	}
503
504	if err := s.transport.writePacket([]byte{msgUserAuthSuccess}); err != nil {
505		return nil, err
506	}
507	return perms, nil
508}
509
510// sshClientKeyboardInteractive implements a ClientKeyboardInteractive by
511// asking the client on the other side of a ServerConn.
512type sshClientKeyboardInteractive struct {
513	*connection
514}
515
516func (c *sshClientKeyboardInteractive) Challenge(user, instruction string, questions []string, echos []bool) (answers []string, err error) {
517	if len(questions) != len(echos) {
518		return nil, errors.New("ssh: echos and questions must have equal length")
519	}
520
521	var prompts []byte
522	for i := range questions {
523		prompts = appendString(prompts, questions[i])
524		prompts = appendBool(prompts, echos[i])
525	}
526
527	if err := c.transport.writePacket(Marshal(&userAuthInfoRequestMsg{
528		Instruction: instruction,
529		NumPrompts:  uint32(len(questions)),
530		Prompts:     prompts,
531	})); err != nil {
532		return nil, err
533	}
534
535	packet, err := c.transport.readPacket()
536	if err != nil {
537		return nil, err
538	}
539	if packet[0] != msgUserAuthInfoResponse {
540		return nil, unexpectedMessageError(msgUserAuthInfoResponse, packet[0])
541	}
542	packet = packet[1:]
543
544	n, packet, ok := parseUint32(packet)
545	if !ok || int(n) != len(questions) {
546		return nil, parseError(msgUserAuthInfoResponse)
547	}
548
549	for i := uint32(0); i < n; i++ {
550		ans, rest, ok := parseString(packet)
551		if !ok {
552			return nil, parseError(msgUserAuthInfoResponse)
553		}
554
555		answers = append(answers, string(ans))
556		packet = rest
557	}
558	if len(packet) != 0 {
559		return nil, errors.New("ssh: junk at end of message")
560	}
561
562	return answers, nil
563}
564