1package sftp
2
3import (
4	"context"
5	"fmt"
6	"io"
7	"net"
8	"os"
9	"testing"
10
11	"github.com/stretchr/testify/assert"
12)
13
14var _ = fmt.Print
15
16type csPair struct {
17	cli *Client
18	svr *RequestServer
19}
20
21// these must be closed in order, else client.Close will hang
22func (cs csPair) Close() {
23	cs.svr.Close()
24	cs.cli.Close()
25	os.Remove(sock)
26}
27
28func (cs csPair) testHandler() *root {
29	return cs.svr.Handlers.FileGet.(*root)
30}
31
// sock is the filesystem path of the unix domain socket over which the test
// client and the request server talk to each other.
const sock = "/tmp/rstest.sock"
33
34func clientRequestServerPair(t *testing.T) *csPair {
35	skipIfWindows(t)
36	ready := make(chan bool)
37	os.Remove(sock) // either this or signal handling
38	var server *RequestServer
39	go func() {
40		l, err := net.Listen("unix", sock)
41		if err != nil {
42			// neither assert nor t.Fatal reliably exit before Accept errors
43			panic(err)
44		}
45		ready <- true
46		fd, err := l.Accept()
47		assert.Nil(t, err)
48		handlers := InMemHandler()
49		server = NewRequestServer(fd, handlers)
50		server.Serve()
51	}()
52	<-ready
53	defer os.Remove(sock)
54	c, err := net.Dial("unix", sock)
55	assert.Nil(t, err)
56	client, err := NewClientPipe(c, c)
57	if err != nil {
58		t.Fatalf("%+v\n", err)
59	}
60	return &csPair{client, server}
61}
62
63// after adding logging, maybe check log to make sure packet handling
64// was split over more than one worker
65func TestRequestSplitWrite(t *testing.T) {
66	p := clientRequestServerPair(t)
67	defer p.Close()
68	w, err := p.cli.Create("/foo")
69	assert.Nil(t, err)
70	p.cli.maxPacket = 3 // force it to send in small chunks
71	contents := "one two three four five six seven eight nine ten"
72	w.Write([]byte(contents))
73	w.Close()
74	r := p.testHandler()
75	f, _ := r.fetch("/foo")
76	assert.Equal(t, contents, string(f.content))
77}
78
79func TestRequestCache(t *testing.T) {
80	p := clientRequestServerPair(t)
81	defer p.Close()
82	foo := NewRequest("", "foo")
83	foo.ctx, foo.cancelCtx = context.WithCancel(context.Background())
84	bar := NewRequest("", "bar")
85	fh := p.svr.nextRequest(foo)
86	bh := p.svr.nextRequest(bar)
87	assert.Len(t, p.svr.openRequests, 2)
88	_foo, ok := p.svr.getRequest(fh)
89	assert.Equal(t, foo.Method, _foo.Method)
90	assert.Equal(t, foo.Filepath, _foo.Filepath)
91	assert.Equal(t, foo.Target, _foo.Target)
92	assert.Equal(t, foo.Flags, _foo.Flags)
93	assert.Equal(t, foo.Attrs, _foo.Attrs)
94	assert.Equal(t, foo.state, _foo.state)
95	assert.NotNil(t, _foo.ctx)
96	assert.Equal(t, _foo.Context().Err(), nil, "context is still valid")
97	assert.True(t, ok)
98	_, ok = p.svr.getRequest("zed")
99	assert.False(t, ok)
100	p.svr.closeRequest(fh)
101	assert.Equal(t, _foo.Context().Err(), context.Canceled, "context is now canceled")
102	p.svr.closeRequest(bh)
103	assert.Len(t, p.svr.openRequests, 0)
104}
105
106func TestRequestCacheState(t *testing.T) {
107	// test operation that uses open/close
108	p := clientRequestServerPair(t)
109	defer p.Close()
110	_, err := putTestFile(p.cli, "/foo", "hello")
111	assert.Nil(t, err)
112	assert.Len(t, p.svr.openRequests, 0)
113	// test operation that doesn't open/close
114	err = p.cli.Remove("/foo")
115	assert.Nil(t, err)
116	assert.Len(t, p.svr.openRequests, 0)
117}
118
119func putTestFile(cli *Client, path, content string) (int, error) {
120	w, err := cli.Create(path)
121	if err == nil {
122		defer w.Close()
123		return w.Write([]byte(content))
124	}
125	return 0, err
126}
127
128func TestRequestWrite(t *testing.T) {
129	p := clientRequestServerPair(t)
130	defer p.Close()
131	n, err := putTestFile(p.cli, "/foo", "hello")
132	assert.Nil(t, err)
133	assert.Equal(t, 5, n)
134	r := p.testHandler()
135	f, err := r.fetch("/foo")
136	assert.Nil(t, err)
137	assert.False(t, f.isdir)
138	assert.Equal(t, f.content, []byte("hello"))
139}
140
141func TestRequestWriteEmpty(t *testing.T) {
142	p := clientRequestServerPair(t)
143	defer p.Close()
144	n, err := putTestFile(p.cli, "/foo", "")
145	assert.NoError(t, err)
146	assert.Equal(t, 0, n)
147	r := p.testHandler()
148	f, err := r.fetch("/foo")
149	if assert.Nil(t, err) {
150		assert.False(t, f.isdir)
151		assert.Len(t, f.content, 0)
152	}
153	// lets test with an error
154	r.returnErr(os.ErrInvalid)
155	n, err = putTestFile(p.cli, "/bar", "")
156	assert.Error(t, err)
157	r.returnErr(nil)
158	assert.Equal(t, 0, n)
159}
160
161func TestRequestFilename(t *testing.T) {
162	p := clientRequestServerPair(t)
163	defer p.Close()
164	_, err := putTestFile(p.cli, "/foo", "hello")
165	assert.NoError(t, err)
166	r := p.testHandler()
167	f, err := r.fetch("/foo")
168	assert.NoError(t, err)
169	assert.Equal(t, f.Name(), "foo")
170	_, err = r.fetch("/bar")
171	assert.Error(t, err)
172}
173
174func TestRequestJustRead(t *testing.T) {
175	p := clientRequestServerPair(t)
176	defer p.Close()
177	_, err := putTestFile(p.cli, "/foo", "hello")
178	assert.Nil(t, err)
179	rf, err := p.cli.Open("/foo")
180	assert.Nil(t, err)
181	defer rf.Close()
182	contents := make([]byte, 5)
183	n, err := rf.Read(contents)
184	if err != nil && err != io.EOF {
185		t.Fatalf("err: %v", err)
186	}
187	assert.Equal(t, 5, n)
188	assert.Equal(t, "hello", string(contents[0:5]))
189}
190
191func TestRequestOpenFail(t *testing.T) {
192	p := clientRequestServerPair(t)
193	defer p.Close()
194	rf, err := p.cli.Open("/foo")
195	assert.Exactly(t, os.ErrNotExist, err)
196	assert.Nil(t, rf)
197}
198
199func TestRequestCreate(t *testing.T) {
200	p := clientRequestServerPair(t)
201	defer p.Close()
202	fh, err := p.cli.Create("foo")
203	assert.Nil(t, err)
204	err = fh.Close()
205	assert.Nil(t, err)
206}
207
208func TestRequestMkdir(t *testing.T) {
209	p := clientRequestServerPair(t)
210	defer p.Close()
211	err := p.cli.Mkdir("/foo")
212	assert.Nil(t, err)
213	r := p.testHandler()
214	f, err := r.fetch("/foo")
215	assert.Nil(t, err)
216	assert.True(t, f.isdir)
217}
218
219func TestRequestRemove(t *testing.T) {
220	p := clientRequestServerPair(t)
221	defer p.Close()
222	_, err := putTestFile(p.cli, "/foo", "hello")
223	assert.Nil(t, err)
224	r := p.testHandler()
225	_, err = r.fetch("/foo")
226	assert.Nil(t, err)
227	err = p.cli.Remove("/foo")
228	assert.Nil(t, err)
229	_, err = r.fetch("/foo")
230	assert.Equal(t, err, os.ErrNotExist)
231}
232
233func TestRequestRename(t *testing.T) {
234	p := clientRequestServerPair(t)
235	defer p.Close()
236	_, err := putTestFile(p.cli, "/foo", "hello")
237	assert.Nil(t, err)
238	r := p.testHandler()
239	_, err = r.fetch("/foo")
240	assert.Nil(t, err)
241	err = p.cli.Rename("/foo", "/bar")
242	assert.Nil(t, err)
243	f, err := r.fetch("/bar")
244	assert.Equal(t, "bar", f.Name())
245	assert.Nil(t, err)
246	_, err = r.fetch("/foo")
247	assert.Equal(t, os.ErrNotExist, err)
248}
249
250func TestRequestRenameFail(t *testing.T) {
251	p := clientRequestServerPair(t)
252	defer p.Close()
253	_, err := putTestFile(p.cli, "/foo", "hello")
254	assert.Nil(t, err)
255	_, err = putTestFile(p.cli, "/bar", "goodbye")
256	assert.Nil(t, err)
257	err = p.cli.Rename("/foo", "/bar")
258	assert.IsType(t, &StatusError{}, err)
259}
260
261func TestRequestStat(t *testing.T) {
262	p := clientRequestServerPair(t)
263	defer p.Close()
264	_, err := putTestFile(p.cli, "/foo", "hello")
265	assert.Nil(t, err)
266	fi, err := p.cli.Stat("/foo")
267	assert.Equal(t, fi.Name(), "foo")
268	assert.Equal(t, fi.Size(), int64(5))
269	assert.Equal(t, fi.Mode(), os.FileMode(0644))
270	assert.NoError(t, testOsSys(fi.Sys()))
271	assert.NoError(t, err)
272}
273
274// NOTE: Setstat is a noop in the request server tests, but we want to test
275// that is does nothing without crapping out.
276func TestRequestSetstat(t *testing.T) {
277	p := clientRequestServerPair(t)
278	defer p.Close()
279	_, err := putTestFile(p.cli, "/foo", "hello")
280	assert.Nil(t, err)
281	mode := os.FileMode(0644)
282	err = p.cli.Chmod("/foo", mode)
283	assert.Nil(t, err)
284	fi, err := p.cli.Stat("/foo")
285	assert.Nil(t, err)
286	assert.Equal(t, fi.Name(), "foo")
287	assert.Equal(t, fi.Size(), int64(5))
288	assert.Equal(t, fi.Mode(), os.FileMode(0644))
289	assert.NoError(t, testOsSys(fi.Sys()))
290}
291
292func TestRequestFstat(t *testing.T) {
293	p := clientRequestServerPair(t)
294	defer p.Close()
295	_, err := putTestFile(p.cli, "/foo", "hello")
296	assert.Nil(t, err)
297	fp, err := p.cli.Open("/foo")
298	assert.Nil(t, err)
299	fi, err := fp.Stat()
300	if assert.NoError(t, err) {
301		assert.Equal(t, fi.Name(), "foo")
302		assert.Equal(t, fi.Size(), int64(5))
303		assert.Equal(t, fi.Mode(), os.FileMode(0644))
304		assert.NoError(t, testOsSys(fi.Sys()))
305	}
306}
307
308func TestRequestStatFail(t *testing.T) {
309	p := clientRequestServerPair(t)
310	defer p.Close()
311	fi, err := p.cli.Stat("/foo")
312	assert.Nil(t, fi)
313	assert.True(t, os.IsNotExist(err))
314}
315
316func TestRequestSymlink(t *testing.T) {
317	p := clientRequestServerPair(t)
318	defer p.Close()
319	_, err := putTestFile(p.cli, "/foo", "hello")
320	assert.Nil(t, err)
321	err = p.cli.Symlink("/foo", "/bar")
322	assert.Nil(t, err)
323	r := p.testHandler()
324	fi, err := r.fetch("/bar")
325	assert.Nil(t, err)
326	assert.True(t, fi.Mode()&os.ModeSymlink == os.ModeSymlink)
327}
328
329func TestRequestSymlinkFail(t *testing.T) {
330	p := clientRequestServerPair(t)
331	defer p.Close()
332	err := p.cli.Symlink("/foo", "/bar")
333	assert.True(t, os.IsNotExist(err))
334}
335
336func TestRequestReadlink(t *testing.T) {
337	p := clientRequestServerPair(t)
338	defer p.Close()
339	_, err := putTestFile(p.cli, "/foo", "hello")
340	assert.Nil(t, err)
341	err = p.cli.Symlink("/foo", "/bar")
342	assert.Nil(t, err)
343	rl, err := p.cli.ReadLink("/bar")
344	assert.Nil(t, err)
345	assert.Equal(t, "foo", rl)
346}
347
348func TestRequestReaddir(t *testing.T) {
349	p := clientRequestServerPair(t)
350	MaxFilelist = 22 // make not divisible by our test amount (100)
351	defer p.Close()
352	for i := 0; i < 100; i++ {
353		fname := fmt.Sprintf("/foo_%02d", i)
354		_, err := putTestFile(p.cli, fname, fname)
355		if err != nil {
356			t.Fatal("expected no error, got:", err)
357		}
358	}
359	_, err := p.cli.ReadDir("/foo_01")
360	assert.Equal(t, &StatusError{Code: ssh_FX_FAILURE,
361		msg: " /foo_01: not a directory"}, err)
362	_, err = p.cli.ReadDir("/does_not_exist")
363	assert.Equal(t, os.ErrNotExist, err)
364	di, err := p.cli.ReadDir("/")
365	assert.Nil(t, err)
366	assert.Len(t, di, 100)
367	names := []string{di[18].Name(), di[81].Name()}
368	assert.Equal(t, []string{"foo_18", "foo_81"}, names)
369}
370
371func TestCleanPath(t *testing.T) {
372	assert.Equal(t, "/", cleanPath("/"))
373	assert.Equal(t, "/", cleanPath("."))
374	assert.Equal(t, "/", cleanPath("/."))
375	assert.Equal(t, "/", cleanPath("/a/.."))
376	assert.Equal(t, "/a/c", cleanPath("/a/b/../c"))
377	assert.Equal(t, "/a/c", cleanPath("/a/b/../c/"))
378	assert.Equal(t, "/a", cleanPath("/a/b/.."))
379	assert.Equal(t, "/a/b/c", cleanPath("/a/b/c"))
380	assert.Equal(t, "/", cleanPath("//"))
381	assert.Equal(t, "/a", cleanPath("/a/"))
382	assert.Equal(t, "/a", cleanPath("a/"))
383	assert.Equal(t, "/a/b/c", cleanPath("/a//b//c/"))
384
385	// filepath.ToSlash does not touch \ as char on unix systems
386	// so os.PathSeparator is used for windows compatible tests
387	bslash := string(os.PathSeparator)
388	assert.Equal(t, "/", cleanPath(bslash))
389	assert.Equal(t, "/", cleanPath(bslash+bslash))
390	assert.Equal(t, "/a", cleanPath(bslash+"a"+bslash))
391	assert.Equal(t, "/a", cleanPath("a"+bslash))
392	assert.Equal(t, "/a/b/c",
393		cleanPath(bslash+"a"+bslash+bslash+"b"+bslash+bslash+"c"+bslash))
394}
395