1package sftp
2
3import (
4	"context"
5	"io"
6	"os"
7	"path"
8	"path/filepath"
9	"sync"
10	"syscall"
11
12	"github.com/pkg/errors"
13)
14
// MaxFilelist is the size of the batch buffer handed to a lister for each
// readdir request, i.e. the maximum number of entries returned per batch.
var MaxFilelist = int64(100)
17
18// Request contains the data and state for the incoming service request.
19type Request struct {
20	// Get, Put, Setstat, Stat, Rename, Remove
21	// Rmdir, Mkdir, List, Readlink, Symlink
22	Method   string
23	Filepath string
24	Flags    uint32
25	Attrs    []byte // convert to sub-struct
26	Target   string // for renames and sym-links
27	handle   string
28	// reader/writer/readdir from handlers
29	state state
30	// context lasts duration of request
31	ctx       context.Context
32	cancelCtx context.CancelFunc
33}
34
35type state struct {
36	*sync.RWMutex
37	writerAt io.WriterAt
38	readerAt io.ReaderAt
39	listerAt ListerAt
40	lsoffset int64
41}
42
43// New Request initialized based on packet data
44func requestFromPacket(ctx context.Context, pkt hasPath) *Request {
45	method := requestMethod(pkt)
46	request := NewRequest(method, pkt.getPath())
47	request.ctx, request.cancelCtx = context.WithCancel(ctx)
48
49	switch p := pkt.(type) {
50	case *sshFxpOpenPacket:
51		request.Flags = p.Pflags
52	case *sshFxpSetstatPacket:
53		request.Flags = p.Flags
54		request.Attrs = p.Attrs.([]byte)
55	case *sshFxpRenamePacket:
56		request.Target = cleanPath(p.Newpath)
57	case *sshFxpSymlinkPacket:
58		request.Target = cleanPath(p.Linkpath)
59	}
60	return request
61}
62
63// NewRequest creates a new Request object.
64func NewRequest(method, path string) *Request {
65	return &Request{Method: method, Filepath: cleanPath(path),
66		state: state{RWMutex: new(sync.RWMutex)}}
67}
68
69// shallow copy of existing request
70func (r *Request) copy() *Request {
71	r.state.Lock()
72	defer r.state.Unlock()
73	r2 := new(Request)
74	*r2 = *r
75	return r2
76}
77
78// Context returns the request's context. To change the context,
79// use WithContext.
80//
81// The returned context is always non-nil; it defaults to the
82// background context.
83//
84// For incoming server requests, the context is canceled when the
85// request is complete or the client's connection closes.
86func (r *Request) Context() context.Context {
87	if r.ctx != nil {
88		return r.ctx
89	}
90	return context.Background()
91}
92
93// WithContext returns a copy of r with its context changed to ctx.
94// The provided ctx must be non-nil.
95func (r *Request) WithContext(ctx context.Context) *Request {
96	if ctx == nil {
97		panic("nil context")
98	}
99	r2 := r.copy()
100	r2.ctx = ctx
101	r2.cancelCtx = nil
102	return r2
103}
104
105// Returns current offset for file list
106func (r *Request) lsNext() int64 {
107	r.state.RLock()
108	defer r.state.RUnlock()
109	return r.state.lsoffset
110}
111
112// Increases next offset
113func (r *Request) lsInc(offset int64) {
114	r.state.Lock()
115	defer r.state.Unlock()
116	r.state.lsoffset = r.state.lsoffset + offset
117}
118
119// manage file read/write state
120func (r *Request) setListerState(la ListerAt) {
121	r.state.Lock()
122	defer r.state.Unlock()
123	r.state.listerAt = la
124}
125
126func (r *Request) getLister() ListerAt {
127	r.state.RLock()
128	defer r.state.RUnlock()
129	return r.state.listerAt
130}
131
132// Close reader/writer if possible
133func (r *Request) close() error {
134	defer func() {
135		if r.cancelCtx != nil {
136			r.cancelCtx()
137		}
138	}()
139	r.state.RLock()
140	rd := r.state.readerAt
141	r.state.RUnlock()
142	if c, ok := rd.(io.Closer); ok {
143		return c.Close()
144	}
145	r.state.RLock()
146	wt := r.state.writerAt
147	r.state.RUnlock()
148	if c, ok := wt.(io.Closer); ok {
149		return c.Close()
150	}
151	return nil
152}
153
154// called from worker to handle packet/request
155func (r *Request) call(handlers Handlers, pkt requestPacket) responsePacket {
156	switch r.Method {
157	case "Get":
158		return fileget(handlers.FileGet, r, pkt)
159	case "Put":
160		return fileput(handlers.FilePut, r, pkt)
161	case "Setstat", "Rename", "Rmdir", "Mkdir", "Symlink", "Remove":
162		return filecmd(handlers.FileCmd, r, pkt)
163	case "List":
164		return filelist(handlers.FileList, r, pkt)
165	case "Stat", "Readlink":
166		return filestat(handlers.FileList, r, pkt)
167	default:
168		return statusFromError(pkt,
169			errors.Errorf("unexpected method: %s", r.Method))
170	}
171}
172
173// Additional initialization for Open packets
174func (r *Request) open(h Handlers, pkt requestPacket) responsePacket {
175	flags := r.Pflags()
176	var err error
177	switch {
178	case flags.Write, flags.Append, flags.Creat, flags.Trunc:
179		r.Method = "Put"
180		r.state.writerAt, err = h.FilePut.Filewrite(r)
181	case flags.Read:
182		r.Method = "Get"
183		r.state.readerAt, err = h.FileGet.Fileread(r)
184	default:
185		return statusFromError(pkt, errors.New("bad file flags"))
186	}
187	if err != nil {
188		return statusFromError(pkt, err)
189	}
190	return &sshFxpHandlePacket{ID: pkt.id(), Handle: r.handle}
191}
192func (r *Request) opendir(h Handlers, pkt requestPacket) responsePacket {
193	var err error
194	r.Method = "List"
195	r.state.listerAt, err = h.FileList.Filelist(r)
196	if err != nil {
197		switch err.(type) {
198		case syscall.Errno:
199			err = &os.PathError{Path: r.Filepath, Err: err}
200		}
201		return statusFromError(pkt, err)
202	}
203	return &sshFxpHandlePacket{ID: pkt.id(), Handle: r.handle}
204}
205
206// wrap FileReader handler
207func fileget(h FileReader, r *Request, pkt requestPacket) responsePacket {
208	//fmt.Println("fileget", r)
209	r.state.RLock()
210	reader := r.state.readerAt
211	r.state.RUnlock()
212	if reader == nil {
213		return statusFromError(pkt, errors.New("unexpected read packet"))
214	}
215
216	_, offset, length := packetData(pkt)
217	data := make([]byte, clamp(length, maxTxPacket))
218	n, err := reader.ReadAt(data, offset)
219	// only return EOF erro if no data left to read
220	if err != nil && (err != io.EOF || n == 0) {
221		return statusFromError(pkt, err)
222	}
223	return &sshFxpDataPacket{
224		ID:     pkt.id(),
225		Length: uint32(n),
226		Data:   data[:n],
227	}
228}
229
230// wrap FileWriter handler
231func fileput(h FileWriter, r *Request, pkt requestPacket) responsePacket {
232	//fmt.Println("fileput", r)
233	r.state.RLock()
234	writer := r.state.writerAt
235	r.state.RUnlock()
236	if writer == nil {
237		return statusFromError(pkt, errors.New("unexpected write packet"))
238	}
239
240	data, offset, _ := packetData(pkt)
241	_, err := writer.WriteAt(data, offset)
242	return statusFromError(pkt, err)
243}
244
245// file data for additional read/write packets
246func packetData(p requestPacket) (data []byte, offset int64, length uint32) {
247	switch p := p.(type) {
248	case *sshFxpReadPacket:
249		length = p.Len
250		offset = int64(p.Offset)
251	case *sshFxpWritePacket:
252		data = p.Data
253		length = p.Length
254		offset = int64(p.Offset)
255	}
256	return
257}
258
259// wrap FileCmder handler
260func filecmd(h FileCmder, r *Request, pkt requestPacket) responsePacket {
261
262	switch p := pkt.(type) {
263	case *sshFxpFsetstatPacket:
264		r.Flags = p.Flags
265		r.Attrs = p.Attrs.([]byte)
266	}
267	err := h.Filecmd(r)
268	return statusFromError(pkt, err)
269}
270
271// wrap FileLister handler
272func filelist(h FileLister, r *Request, pkt requestPacket) responsePacket {
273	var err error
274	lister := r.getLister()
275	if lister == nil {
276		return statusFromError(pkt, errors.New("unexpected dir packet"))
277	}
278
279	offset := r.lsNext()
280	finfo := make([]os.FileInfo, MaxFilelist)
281	n, err := lister.ListAt(finfo, offset)
282	r.lsInc(int64(n))
283	// ignore EOF as we only return it when there are no results
284	finfo = finfo[:n] // avoid need for nil tests below
285
286	switch r.Method {
287	case "List":
288		if err != nil && err != io.EOF {
289			return statusFromError(pkt, err)
290		}
291		if err == io.EOF && n == 0 {
292			return statusFromError(pkt, io.EOF)
293		}
294		dirname := filepath.ToSlash(path.Base(r.Filepath))
295		ret := &sshFxpNamePacket{ID: pkt.id()}
296
297		for _, fi := range finfo {
298			ret.NameAttrs = append(ret.NameAttrs, sshFxpNameAttr{
299				Name:     fi.Name(),
300				LongName: runLs(dirname, fi),
301				Attrs:    []interface{}{fi},
302			})
303		}
304		return ret
305	default:
306		err = errors.Errorf("unexpected method: %s", r.Method)
307		return statusFromError(pkt, err)
308	}
309}
310
311func filestat(h FileLister, r *Request, pkt requestPacket) responsePacket {
312	lister, err := h.Filelist(r)
313	if err != nil {
314		return statusFromError(pkt, err)
315	}
316	finfo := make([]os.FileInfo, 1)
317	n, err := lister.ListAt(finfo, 0)
318	finfo = finfo[:n] // avoid need for nil tests below
319
320	switch r.Method {
321	case "Stat":
322		if err != nil && err != io.EOF {
323			return statusFromError(pkt, err)
324		}
325		if n == 0 {
326			err = &os.PathError{Op: "stat", Path: r.Filepath,
327				Err: syscall.ENOENT}
328			return statusFromError(pkt, err)
329		}
330		return &sshFxpStatResponse{
331			ID:   pkt.id(),
332			info: finfo[0],
333		}
334	case "Readlink":
335		if err != nil && err != io.EOF {
336			return statusFromError(pkt, err)
337		}
338		if n == 0 {
339			err = &os.PathError{Op: "readlink", Path: r.Filepath,
340				Err: syscall.ENOENT}
341			return statusFromError(pkt, err)
342		}
343		filename := finfo[0].Name()
344		return &sshFxpNamePacket{
345			ID: pkt.id(),
346			NameAttrs: []sshFxpNameAttr{{
347				Name:     filename,
348				LongName: filename,
349				Attrs:    emptyFileStat,
350			}},
351		}
352	default:
353		err = errors.Errorf("unexpected method: %s", r.Method)
354		return statusFromError(pkt, err)
355	}
356}
357
358// init attributes of request object from packet data
359func requestMethod(p requestPacket) (method string) {
360	switch p.(type) {
361	case *sshFxpReadPacket, *sshFxpWritePacket, *sshFxpOpenPacket:
362		// set in open() above
363	case *sshFxpOpendirPacket, *sshFxpReaddirPacket:
364		// set in opendir() above
365	case *sshFxpSetstatPacket, *sshFxpFsetstatPacket:
366		method = "Setstat"
367	case *sshFxpRenamePacket:
368		method = "Rename"
369	case *sshFxpSymlinkPacket:
370		method = "Symlink"
371	case *sshFxpRemovePacket:
372		method = "Remove"
373	case *sshFxpStatPacket, *sshFxpLstatPacket, *sshFxpFstatPacket:
374		method = "Stat"
375	case *sshFxpRmdirPacket:
376		method = "Rmdir"
377	case *sshFxpReadlinkPacket:
378		method = "Readlink"
379	case *sshFxpMkdirPacket:
380		method = "Mkdir"
381	}
382	return method
383}
384