1package sftp
2
3import (
4	"encoding"
5	"fmt"
6	"io"
7	"sync"
8)
9
// conn implements a bidirectional channel on which client and server
// connections are multiplexed.
//
// Incoming packets are read through the embedded io.Reader and outgoing
// packets are written through the embedded io.WriteCloser. The embedded
// Mutex is held by sendPacket and Close, so a Close never runs in the middle
// of a packet write and concurrent senders cannot interleave their packets.
type conn struct {
	io.Reader
	io.WriteCloser
	// this is the same allocator used in packet manager; (*conn).recvPacket
	// hands it to the package-level recvPacket together with the orderID.
	alloc      *allocator
	sync.Mutex // used to serialise writes to sendPacket
}
19
20// the orderID is used in server mode if the allocator is enabled.
21// For the client mode just pass 0
22func (c *conn) recvPacket(orderID uint32) (uint8, []byte, error) {
23	return recvPacket(c, c.alloc, orderID)
24}
25
26func (c *conn) sendPacket(m encoding.BinaryMarshaler) error {
27	c.Lock()
28	defer c.Unlock()
29
30	return sendPacket(c, m)
31}
32
33func (c *conn) Close() error {
34	c.Lock()
35	defer c.Unlock()
36	return c.WriteCloser.Close()
37}
38
// clientConn is the client end of a conn. Each outgoing request registers a
// result channel in inflight under its request ID; the receive loop (recv)
// removes that entry and delivers the server's response on the channel.
type clientConn struct {
	conn
	// wg is marked done by loop when the receive loop exits; Close waits on it.
	wg sync.WaitGroup

	sync.Mutex                          // protects inflight
	inflight   map[uint32]chan<- result // outstanding requests

	// closed is closed by broadcastErr once the connection has failed; err
	// holds the error that caused the failure and is what Wait returns.
	closed chan struct{}
	err    error
}
49
50// Wait blocks until the conn has shut down, and return the error
51// causing the shutdown. It can be called concurrently from multiple
52// goroutines.
53func (c *clientConn) Wait() error {
54	<-c.closed
55	return c.err
56}
57
58// Close closes the SFTP session.
59func (c *clientConn) Close() error {
60	defer c.wg.Wait()
61	return c.conn.Close()
62}
63
64func (c *clientConn) loop() {
65	defer c.wg.Done()
66	err := c.recv()
67	if err != nil {
68		c.broadcastErr(err)
69	}
70}
71
72// recv continuously reads from the server and forwards responses to the
73// appropriate channel.
74func (c *clientConn) recv() error {
75	defer c.conn.Close()
76
77	for {
78		typ, data, err := c.recvPacket(0)
79		if err != nil {
80			return err
81		}
82		sid, _, err := unmarshalUint32Safe(data)
83		if err != nil {
84			return err
85		}
86
87		ch, ok := c.getChannel(sid)
88		if !ok {
89			// This is an unexpected occurrence. Send the error
90			// back to all listeners so that they terminate
91			// gracefully.
92			return fmt.Errorf("sid not found: %d", sid)
93		}
94
95		ch <- result{typ: typ, data: data}
96	}
97}
98
99func (c *clientConn) putChannel(ch chan<- result, sid uint32) bool {
100	c.Lock()
101	defer c.Unlock()
102
103	select {
104	case <-c.closed:
105		// already closed with broadcastErr, return error on chan.
106		ch <- result{err: ErrSSHFxConnectionLost}
107		return false
108	default:
109	}
110
111	c.inflight[sid] = ch
112	return true
113}
114
115func (c *clientConn) getChannel(sid uint32) (chan<- result, bool) {
116	c.Lock()
117	defer c.Unlock()
118
119	ch, ok := c.inflight[sid]
120	delete(c.inflight, sid)
121
122	return ch, ok
123}
124
// result captures the result of receiving a packet from the server: either
// the packet type and its raw data, or an error that prevented a response.
type result struct {
	typ  byte
	data []byte
	err  error
}
131
// idmarshaler is a request that can be marshalled for sending and that
// reports the request ID its response will be matched against.
type idmarshaler interface {
	encoding.BinaryMarshaler
	id() uint32
}
136
137func (c *clientConn) sendPacket(ch chan result, p idmarshaler) (byte, []byte, error) {
138	if cap(ch) < 1 {
139		ch = make(chan result, 1)
140	}
141
142	c.dispatchRequest(ch, p)
143	s := <-ch
144	return s.typ, s.data, s.err
145}
146
147// dispatchRequest should ideally only be called by race-detection tests outside of this file,
148// where you have to ensure two packets are in flight sequentially after each other.
149func (c *clientConn) dispatchRequest(ch chan<- result, p idmarshaler) {
150	sid := p.id()
151
152	if !c.putChannel(ch, sid) {
153		// already closed.
154		return
155	}
156
157	if err := c.conn.sendPacket(p); err != nil {
158		if ch, ok := c.getChannel(sid); ok {
159			ch <- result{err: err}
160		}
161	}
162}
163
// broadcastErr sends an error to all goroutines waiting for a response.
//
// Every channel currently registered in inflight is sent
// ErrSSHFxConnectionLost (not err itself). err is then recorded in c.err and
// c.closed is closed, which unblocks Wait and makes later putChannel calls
// fail immediately. It must be called at most once: a second call would close
// c.closed again, which panics.
func (c *clientConn) broadcastErr(err error) {
	c.Lock()
	defer c.Unlock()

	bcastRes := result{err: ErrSSHFxConnectionLost}
	for sid, ch := range c.inflight {
		// This send happens while c's lock is held, so it must not block.
		// sendPacket guarantees a buffer of at least one; direct callers of
		// dispatchRequest need to do the same.
		ch <- bcastRes

		// Replace the chan in inflight,
		// we have hijacked this chan,
		// and this guarantees always-only-once sending.
		// (A later getChannel for this sid, e.g. dispatchRequest's error
		// path, gets the buffered replacement and its send is absorbed.)
		c.inflight[sid] = make(chan<- result, 1)
	}

	// c.err is assigned before c.closed is closed so that Wait, which reads
	// c.err after receiving from c.closed, observes it without locking.
	c.err = err
	close(c.closed)
}
182
// serverConn is the server end of a conn. It shares conn's packet I/O and
// adds sendError for replying with an error status.
type serverConn struct {
	conn
}
186
187func (s *serverConn) sendError(id uint32, err error) error {
188	return s.sendPacket(statusFromError(id, err))
189}
190