1// Copyright 2014 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package agent
6
7import (
8	"errors"
9	"io"
10	"net"
11	"sync"
12
13	"golang.org/x/crypto/ssh"
14)
15
16// RequestAgentForwarding sets up agent forwarding for the session.
17// ForwardToAgent or ForwardToRemote should be called to route
18// the authentication requests.
19func RequestAgentForwarding(session *ssh.Session) error {
20	ok, err := session.SendRequest("auth-agent-req@openssh.com", true, nil)
21	if err != nil {
22		return err
23	}
24	if !ok {
25		return errors.New("forwarding request denied")
26	}
27	return nil
28}
29
30// ForwardToAgent routes authentication requests to the given keyring.
31func ForwardToAgent(client *ssh.Client, keyring Agent) error {
32	channels := client.HandleChannelOpen(channelType)
33	if channels == nil {
34		return errors.New("agent: already have handler for " + channelType)
35	}
36
37	go func() {
38		for ch := range channels {
39			channel, reqs, err := ch.Accept()
40			if err != nil {
41				continue
42			}
43			go ssh.DiscardRequests(reqs)
44			go func() {
45				ServeAgent(keyring, channel)
46				channel.Close()
47			}()
48		}
49	}()
50	return nil
51}
52
// channelType is the SSH channel type used for forwarded agent connections.
// Both ForwardToAgent and ForwardToRemote register a handler for it.
const (
	channelType = "auth-agent@openssh.com"
)
54
55// ForwardToRemote routes authentication requests to the ssh-agent
56// process serving on the given unix socket.
57func ForwardToRemote(client *ssh.Client, addr string) error {
58	channels := client.HandleChannelOpen(channelType)
59	if channels == nil {
60		return errors.New("agent: already have handler for " + channelType)
61	}
62	conn, err := net.Dial("unix", addr)
63	if err != nil {
64		return err
65	}
66	conn.Close()
67
68	go func() {
69		for ch := range channels {
70			channel, reqs, err := ch.Accept()
71			if err != nil {
72				continue
73			}
74			go ssh.DiscardRequests(reqs)
75			go forwardUnixSocket(channel, addr)
76		}
77	}()
78	return nil
79}
80
81func forwardUnixSocket(channel ssh.Channel, addr string) {
82	conn, err := net.Dial("unix", addr)
83	if err != nil {
84		return
85	}
86
87	var wg sync.WaitGroup
88	wg.Add(2)
89	go func() {
90		io.Copy(conn, channel)
91		conn.(*net.UnixConn).CloseWrite()
92		wg.Done()
93	}()
94	go func() {
95		io.Copy(channel, conn)
96		channel.CloseWrite()
97		wg.Done()
98	}()
99
100	wg.Wait()
101	conn.Close()
102	channel.Close()
103}
104