1package sftp
2
3import (
4	"context"
5	"encoding"
6	"io"
7	"path"
8	"path/filepath"
9	"strconv"
10	"sync"
11	"syscall"
12
13	"github.com/pkg/errors"
14)
15
// maxTxPacket holds 32 KiB (1 << 15) as a uint32. Presumably the maximum
// payload size per transmitted packet — confirm against its users.
var maxTxPacket = uint32(32 * 1024)
17
// Handlers contains the 4 SFTP server request handlers.
//
// The whole struct is handed to Request.call for each packet the
// RequestServer dispatches; which field ends up serving a given packet is
// decided by that call, not here.
type Handlers struct {
	// FileGet is the FileReader handler (presumably used for reads — confirm in Request.call).
	FileGet FileReader
	// FilePut is the FileWriter handler (presumably used for writes — confirm in Request.call).
	FilePut FileWriter
	// FileCmd is the FileCmder handler (presumably used for non-I/O commands — confirm in Request.call).
	FileCmd FileCmder
	// FileList is the FileLister handler (presumably used for listings/stat — confirm in Request.call).
	FileList FileLister
}
25
// RequestServer abstracts the sftp protocol with an http request-like protocol.
//
// It reads packets from a single connection (see Serve), keeps track of the
// handles the client has opened, and dispatches work to Handlers.
type RequestServer struct {
	// Underlying connection; Close closes its conn.
	*serverConn
	// Handlers receives the dispatched requests.
	Handlers Handlers
	// pktMgr feeds packets to the workers (Serve) and receives responses (sendPacket).
	pktMgr *packetManager
	// openRequests maps the handle strings given to the client to their Request.
	openRequests map[string]*Request
	// openRequestLock guards openRequests and handleCount.
	openRequestLock sync.RWMutex
	// handleCount is the number used for the most recently issued handle.
	handleCount int
}
35
36// NewRequestServer creates/allocates/returns new RequestServer.
37// Normally there there will be one server per user-session.
38func NewRequestServer(rwc io.ReadWriteCloser, h Handlers) *RequestServer {
39	svrConn := &serverConn{
40		conn: conn{
41			Reader:      rwc,
42			WriteCloser: rwc,
43		},
44	}
45	return &RequestServer{
46		serverConn:   svrConn,
47		Handlers:     h,
48		pktMgr:       newPktMgr(svrConn),
49		openRequests: make(map[string]*Request),
50	}
51}
52
53// New Open packet/Request
54func (rs *RequestServer) nextRequest(r *Request) string {
55	rs.openRequestLock.Lock()
56	defer rs.openRequestLock.Unlock()
57	rs.handleCount++
58	handle := strconv.Itoa(rs.handleCount)
59	rs.openRequests[handle] = r
60	return handle
61}
62
63// Returns Request from openRequests, bool is false if it is missing
64// If the method is different, save/return a new Request w/ that Method.
65//
66// The Requests in openRequests work essentially as open file descriptors that
67// you can do different things with. What you are doing with it are denoted by
68// the first packet of that type (read/write/etc). We create a new Request when
69// it changes to set the request.Method attribute in a thread safe way.
70func (rs *RequestServer) getRequest(handle, method string) (*Request, bool) {
71	rs.openRequestLock.RLock()
72	r, ok := rs.openRequests[handle]
73	rs.openRequestLock.RUnlock()
74	if !ok || r.Method == method {
75		return r, ok
76	}
77	// if we make it here we need to replace the request
78	rs.openRequestLock.Lock()
79	defer rs.openRequestLock.Unlock()
80	r, ok = rs.openRequests[handle]
81	if !ok || r.Method == method { // re-check needed b/c lock race
82		return r, ok
83	}
84	r = r.copy()
85	r.Method = method
86	rs.openRequests[handle] = r
87	return r, ok
88}
89
90func (rs *RequestServer) closeRequest(handle string) error {
91	rs.openRequestLock.Lock()
92	defer rs.openRequestLock.Unlock()
93	if r, ok := rs.openRequests[handle]; ok {
94		delete(rs.openRequests, handle)
95		return r.close()
96	}
97	return syscall.EBADF
98}
99
100// Close the read/write/closer to trigger exiting the main server loop
101func (rs *RequestServer) Close() error { return rs.conn.Close() }
102
103// Serve requests for user session
104func (rs *RequestServer) Serve() error {
105	ctx, cancel := context.WithCancel(context.Background())
106	defer cancel()
107	var wg sync.WaitGroup
108	runWorker := func(ch requestChan) {
109		wg.Add(1)
110		go func() {
111			defer wg.Done()
112			if err := rs.packetWorker(ctx, ch); err != nil {
113				rs.conn.Close() // shuts down recvPacket
114			}
115		}()
116	}
117	pktChan := rs.pktMgr.workerChan(runWorker)
118
119	var err error
120	var pkt requestPacket
121	var pktType uint8
122	var pktBytes []byte
123	for {
124		pktType, pktBytes, err = rs.recvPacket()
125		if err != nil {
126			break
127		}
128
129		pkt, err = makePacket(rxPacket{fxp(pktType), pktBytes})
130		if err != nil {
131			switch errors.Cause(err) {
132			case errUnknownExtendedPacket:
133				if err := rs.serverConn.sendError(pkt, ErrSshFxOpUnsupported); err != nil {
134					debug("failed to send err packet: %v", err)
135					rs.conn.Close() // shuts down recvPacket
136					break
137				}
138			default:
139				debug("makePacket err: %v", err)
140				rs.conn.Close() // shuts down recvPacket
141				break
142			}
143		}
144
145		pktChan <- pkt
146	}
147
148	close(pktChan) // shuts down sftpServerWorkers
149	wg.Wait()      // wait for all workers to exit
150
151	// make sure all open requests are properly closed
152	// (eg. possible on dropped connections, client crashes, etc.)
153	for handle, req := range rs.openRequests {
154		delete(rs.openRequests, handle)
155		req.close()
156	}
157
158	return err
159}
160
161func (rs *RequestServer) packetWorker(
162	ctx context.Context, pktChan chan requestPacket,
163) error {
164	for pkt := range pktChan {
165		var rpkt responsePacket
166		switch pkt := pkt.(type) {
167		case *sshFxInitPacket:
168			rpkt = sshFxVersionPacket{sftpProtocolVersion, nil}
169		case *sshFxpClosePacket:
170			handle := pkt.getHandle()
171			rpkt = statusFromError(pkt, rs.closeRequest(handle))
172		case *sshFxpRealpathPacket:
173			rpkt = cleanPacketPath(pkt)
174		case *sshFxpOpendirPacket:
175			request := requestFromPacket(ctx, pkt)
176			handle := rs.nextRequest(request)
177			rpkt = sshFxpHandlePacket{pkt.id(), handle}
178		case *sshFxpOpenPacket:
179			request := requestFromPacket(ctx, pkt)
180			handle := rs.nextRequest(request)
181			rpkt = sshFxpHandlePacket{pkt.id(), handle}
182			if pkt.hasPflags(ssh_FXF_CREAT) {
183				if p := request.call(rs.Handlers, pkt); !statusOk(p) {
184					rpkt = p // if error in write, return it
185				}
186			}
187		case hasHandle:
188			handle := pkt.getHandle()
189			request, ok := rs.getRequest(handle, requestMethod(pkt))
190			if !ok {
191				rpkt = statusFromError(pkt, syscall.EBADF)
192			} else {
193				rpkt = request.call(rs.Handlers, pkt)
194			}
195		case hasPath:
196			request := requestFromPacket(ctx, pkt)
197			rpkt = request.call(rs.Handlers, pkt)
198			request.close()
199		default:
200			return errors.Errorf("unexpected packet type %T", pkt)
201		}
202
203		err := rs.sendPacket(rpkt)
204		if err != nil {
205			return err
206		}
207	}
208	return nil
209}
210
211// True is responsePacket is an OK status packet
212func statusOk(rpkt responsePacket) bool {
213	p, ok := rpkt.(sshFxpStatusPacket)
214	return ok && p.StatusError.Code == ssh_FX_OK
215}
216
217// clean and return name packet for file
218func cleanPacketPath(pkt *sshFxpRealpathPacket) responsePacket {
219	path := cleanPath(pkt.getPath())
220	return &sshFxpNamePacket{
221		ID: pkt.id(),
222		NameAttrs: []sshFxpNameAttr{{
223			Name:     path,
224			LongName: path,
225			Attrs:    emptyFileStat,
226		}},
227	}
228}
229
// cleanPath returns p as a clean, absolute POSIX ("/"-separated) path.
// The host's path separator is first converted to "/", a path that is not
// absolute is rooted at "/", and the result goes through path.Clean. For
// example, "" and "../.." both become "/", and "a/./b/" becomes "/a/b".
func cleanPath(p string) string {
	p = filepath.ToSlash(p)
	// SFTP paths are POSIX paths, so absoluteness must be judged with
	// path.IsAbs. filepath.IsAbs applies host rules: on Windows it treats
	// "C:/x" as absolute, which would leak a drive-letter path through
	// unrooted, and it treats "/x" as relative.
	if !path.IsAbs(p) {
		p = "/" + p
	}
	return path.Clean(p)
}
238
239// Wrap underlying connection methods to use packetManager
240func (rs *RequestServer) sendPacket(m encoding.BinaryMarshaler) error {
241	if pkt, ok := m.(responsePacket); ok {
242		rs.pktMgr.readyPacket(pkt)
243	} else {
244		return errors.Errorf("unexpected packet type %T", m)
245	}
246	return nil
247}
248