1package sftp 2 3import ( 4 "context" 5 "fmt" 6 "io" 7 "net" 8 "os" 9 "testing" 10 11 "github.com/stretchr/testify/assert" 12) 13 14var _ = fmt.Print 15 16type csPair struct { 17 cli *Client 18 svr *RequestServer 19} 20 21// these must be closed in order, else client.Close will hang 22func (cs csPair) Close() { 23 cs.svr.Close() 24 cs.cli.Close() 25 os.Remove(sock) 26} 27 28func (cs csPair) testHandler() *root { 29 return cs.svr.Handlers.FileGet.(*root) 30} 31 32const sock = "/tmp/rstest.sock" 33 34func clientRequestServerPair(t *testing.T) *csPair { 35 skipIfWindows(t) 36 ready := make(chan bool) 37 os.Remove(sock) // either this or signal handling 38 var server *RequestServer 39 go func() { 40 l, err := net.Listen("unix", sock) 41 if err != nil { 42 // neither assert nor t.Fatal reliably exit before Accept errors 43 panic(err) 44 } 45 ready <- true 46 fd, err := l.Accept() 47 assert.Nil(t, err) 48 handlers := InMemHandler() 49 server = NewRequestServer(fd, handlers) 50 server.Serve() 51 }() 52 <-ready 53 defer os.Remove(sock) 54 c, err := net.Dial("unix", sock) 55 assert.Nil(t, err) 56 client, err := NewClientPipe(c, c) 57 if err != nil { 58 t.Fatalf("%+v\n", err) 59 } 60 return &csPair{client, server} 61} 62 63// after adding logging, maybe check log to make sure packet handling 64// was split over more than one worker 65func TestRequestSplitWrite(t *testing.T) { 66 p := clientRequestServerPair(t) 67 defer p.Close() 68 w, err := p.cli.Create("/foo") 69 assert.Nil(t, err) 70 p.cli.maxPacket = 3 // force it to send in small chunks 71 contents := "one two three four five six seven eight nine ten" 72 w.Write([]byte(contents)) 73 w.Close() 74 r := p.testHandler() 75 f, _ := r.fetch("/foo") 76 assert.Equal(t, contents, string(f.content)) 77} 78 79func TestRequestCache(t *testing.T) { 80 p := clientRequestServerPair(t) 81 defer p.Close() 82 foo := NewRequest("", "foo") 83 foo.ctx, foo.cancelCtx = context.WithCancel(context.Background()) 84 bar := NewRequest("", "bar") 85 fh := p.svr.nextRequest(foo) 86 bh := p.svr.nextRequest(bar) 87 assert.Len(t, p.svr.openRequests, 2) 88 _foo, ok := p.svr.getRequest(fh) 89 assert.Equal(t, foo.Method, _foo.Method) 90 assert.Equal(t, foo.Filepath, _foo.Filepath) 91 assert.Equal(t, foo.Target, _foo.Target) 92 assert.Equal(t, foo.Flags, _foo.Flags) 93 assert.Equal(t, foo.Attrs, _foo.Attrs) 94 assert.Equal(t, foo.state, _foo.state) 95 assert.NotNil(t, _foo.ctx) 96 assert.Equal(t, _foo.Context().Err(), nil, "context is still valid") 97 assert.True(t, ok) 98 _, ok = p.svr.getRequest("zed") 99 assert.False(t, ok) 100 p.svr.closeRequest(fh) 101 assert.Equal(t, _foo.Context().Err(), context.Canceled, "context is now canceled") 102 p.svr.closeRequest(bh) 103 assert.Len(t, p.svr.openRequests, 0) 104} 105 106func TestRequestCacheState(t *testing.T) { 107 // test operation that uses open/close 108 p := clientRequestServerPair(t) 109 defer p.Close() 110 _, err := putTestFile(p.cli, "/foo", "hello") 111 assert.Nil(t, err) 112 assert.Len(t, p.svr.openRequests, 0) 113 // test operation that doesn't open/close 114 err = p.cli.Remove("/foo") 115 assert.Nil(t, err) 116 assert.Len(t, p.svr.openRequests, 0) 117} 118 119func putTestFile(cli *Client, path, content string) (int, error) { 120 w, err := cli.Create(path) 121 if err == nil { 122 defer w.Close() 123 return w.Write([]byte(content)) 124 } 125 return 0, err 126} 127 128func TestRequestWrite(t *testing.T) { 129 p := clientRequestServerPair(t) 130 defer p.Close() 131 n, err := putTestFile(p.cli, "/foo", "hello") 132 assert.Nil(t, err) 133 assert.Equal(t, 5, n) 134 r := p.testHandler() 135 f, err := r.fetch("/foo") 136 assert.Nil(t, err) 137 assert.False(t, f.isdir) 138 assert.Equal(t, f.content, []byte("hello")) 139} 140 141func TestRequestWriteEmpty(t *testing.T) { 142 p := clientRequestServerPair(t) 143 defer p.Close() 144 n, err := putTestFile(p.cli, "/foo", "") 145 assert.NoError(t, err) 146 assert.Equal(t, 0, n) 147 r := p.testHandler() 148 f, err := r.fetch("/foo") 149 if assert.Nil(t, err) { 150 assert.False(t, f.isdir) 151 assert.Len(t, f.content, 0) 152 } 153 // lets test with an error 154 r.returnErr(os.ErrInvalid) 155 n, err = putTestFile(p.cli, "/bar", "") 156 assert.Error(t, err) 157 r.returnErr(nil) 158 assert.Equal(t, 0, n) 159} 160 161func TestRequestFilename(t *testing.T) { 162 p := clientRequestServerPair(t) 163 defer p.Close() 164 _, err := putTestFile(p.cli, "/foo", "hello") 165 assert.NoError(t, err) 166 r := p.testHandler() 167 f, err := r.fetch("/foo") 168 assert.NoError(t, err) 169 assert.Equal(t, f.Name(), "foo") 170 _, err = r.fetch("/bar") 171 assert.Error(t, err) 172} 173 174func TestRequestJustRead(t *testing.T) { 175 p := clientRequestServerPair(t) 176 defer p.Close() 177 _, err := putTestFile(p.cli, "/foo", "hello") 178 assert.Nil(t, err) 179 rf, err := p.cli.Open("/foo") 180 assert.Nil(t, err) 181 defer rf.Close() 182 contents := make([]byte, 5) 183 n, err := rf.Read(contents) 184 if err != nil && err != io.EOF { 185 t.Fatalf("err: %v", err) 186 } 187 assert.Equal(t, 5, n) 188 assert.Equal(t, "hello", string(contents[0:5])) 189} 190 191func TestRequestOpenFail(t *testing.T) { 192 p := clientRequestServerPair(t) 193 defer p.Close() 194 rf, err := p.cli.Open("/foo") 195 assert.Exactly(t, os.ErrNotExist, err) 196 assert.Nil(t, rf) 197} 198 199func TestRequestCreate(t *testing.T) { 200 p := clientRequestServerPair(t) 201 defer p.Close() 202 fh, err := p.cli.Create("foo") 203 assert.Nil(t, err) 204 err = fh.Close() 205 assert.Nil(t, err) 206} 207 208func TestRequestMkdir(t *testing.T) { 209 p := clientRequestServerPair(t) 210 defer p.Close() 211 err := p.cli.Mkdir("/foo") 212 assert.Nil(t, err) 213 r := p.testHandler() 214 f, err := r.fetch("/foo") 215 assert.Nil(t, err) 216 assert.True(t, f.isdir) 217} 218 219func TestRequestRemove(t *testing.T) { 220 p := clientRequestServerPair(t) 221 defer p.Close() 222 _, err := putTestFile(p.cli, "/foo", "hello") 223 assert.Nil(t, err) 224 r := p.testHandler() 225 _, err = r.fetch("/foo") 226 assert.Nil(t, err) 227 err = p.cli.Remove("/foo") 228 assert.Nil(t, err) 229 _, err = r.fetch("/foo") 230 assert.Equal(t, err, os.ErrNotExist) 231} 232 233func TestRequestRename(t *testing.T) { 234 p := clientRequestServerPair(t) 235 defer p.Close() 236 _, err := putTestFile(p.cli, "/foo", "hello") 237 assert.Nil(t, err) 238 r := p.testHandler() 239 _, err = r.fetch("/foo") 240 assert.Nil(t, err) 241 err = p.cli.Rename("/foo", "/bar") 242 assert.Nil(t, err) 243 f, err := r.fetch("/bar") 244 assert.Equal(t, "bar", f.Name()) 245 assert.Nil(t, err) 246 _, err = r.fetch("/foo") 247 assert.Equal(t, os.ErrNotExist, err) 248} 249 250func TestRequestRenameFail(t *testing.T) { 251 p := clientRequestServerPair(t) 252 defer p.Close() 253 _, err := putTestFile(p.cli, "/foo", "hello") 254 assert.Nil(t, err) 255 _, err = putTestFile(p.cli, "/bar", "goodbye") 256 assert.Nil(t, err) 257 err = p.cli.Rename("/foo", "/bar") 258 assert.IsType(t, &StatusError{}, err) 259} 260 261func TestRequestStat(t *testing.T) { 262 p := clientRequestServerPair(t) 263 defer p.Close() 264 _, err := putTestFile(p.cli, "/foo", "hello") 265 assert.Nil(t, err) 266 fi, err := p.cli.Stat("/foo") 267 assert.Equal(t, fi.Name(), "foo") 268 assert.Equal(t, fi.Size(), int64(5)) 269 assert.Equal(t, fi.Mode(), os.FileMode(0644)) 270 assert.NoError(t, testOsSys(fi.Sys())) 271 assert.NoError(t, err) 272} 273 274// NOTE: Setstat is a noop in the request server tests, but we want to test 275// that is does nothing without crapping out. 276func TestRequestSetstat(t *testing.T) { 277 p := clientRequestServerPair(t) 278 defer p.Close() 279 _, err := putTestFile(p.cli, "/foo", "hello") 280 assert.Nil(t, err) 281 mode := os.FileMode(0644) 282 err = p.cli.Chmod("/foo", mode) 283 assert.Nil(t, err) 284 fi, err := p.cli.Stat("/foo") 285 assert.Nil(t, err) 286 assert.Equal(t, fi.Name(), "foo") 287 assert.Equal(t, fi.Size(), int64(5)) 288 assert.Equal(t, fi.Mode(), os.FileMode(0644)) 289 assert.NoError(t, testOsSys(fi.Sys())) 290} 291 292func TestRequestFstat(t *testing.T) { 293 p := clientRequestServerPair(t) 294 defer p.Close() 295 _, err := putTestFile(p.cli, "/foo", "hello") 296 assert.Nil(t, err) 297 fp, err := p.cli.Open("/foo") 298 assert.Nil(t, err) 299 fi, err := fp.Stat() 300 if assert.NoError(t, err) { 301 assert.Equal(t, fi.Name(), "foo") 302 assert.Equal(t, fi.Size(), int64(5)) 303 assert.Equal(t, fi.Mode(), os.FileMode(0644)) 304 assert.NoError(t, testOsSys(fi.Sys())) 305 } 306} 307 308func TestRequestStatFail(t *testing.T) { 309 p := clientRequestServerPair(t) 310 defer p.Close() 311 fi, err := p.cli.Stat("/foo") 312 assert.Nil(t, fi) 313 assert.True(t, os.IsNotExist(err)) 314} 315 316func TestRequestSymlink(t *testing.T) { 317 p := clientRequestServerPair(t) 318 defer p.Close() 319 _, err := putTestFile(p.cli, "/foo", "hello") 320 assert.Nil(t, err) 321 err = p.cli.Symlink("/foo", "/bar") 322 assert.Nil(t, err) 323 r := p.testHandler() 324 fi, err := r.fetch("/bar") 325 assert.Nil(t, err) 326 assert.True(t, fi.Mode()&os.ModeSymlink == os.ModeSymlink) 327} 328 329func TestRequestSymlinkFail(t *testing.T) { 330 p := clientRequestServerPair(t) 331 defer p.Close() 332 err := p.cli.Symlink("/foo", "/bar") 333 assert.True(t, os.IsNotExist(err)) 334} 335 336func TestRequestReadlink(t *testing.T) { 337 p := clientRequestServerPair(t) 338 defer p.Close() 339 _, err := putTestFile(p.cli, "/foo", "hello") 340 assert.Nil(t, err) 341 err = p.cli.Symlink("/foo", "/bar") 342 assert.Nil(t, err) 343 rl, err := p.cli.ReadLink("/bar") 344 assert.Nil(t, err) 345 assert.Equal(t, "foo", rl) 346} 347 348func TestRequestReaddir(t *testing.T) { 349 p := clientRequestServerPair(t) 350 MaxFilelist = 22 // make not divisible by our test amount (100) 351 defer p.Close() 352 for i := 0; i < 100; i++ { 353 fname := fmt.Sprintf("/foo_%02d", i) 354 _, err := putTestFile(p.cli, fname, fname) 355 if err != nil { 356 t.Fatal("expected no error, got:", err) 357 } 358 } 359 _, err := p.cli.ReadDir("/foo_01") 360 assert.Equal(t, &StatusError{Code: ssh_FX_FAILURE, 361 msg: " /foo_01: not a directory"}, err) 362 _, err = p.cli.ReadDir("/does_not_exist") 363 assert.Equal(t, os.ErrNotExist, err) 364 di, err := p.cli.ReadDir("/") 365 assert.Nil(t, err) 366 assert.Len(t, di, 100) 367 names := []string{di[18].Name(), di[81].Name()} 368 assert.Equal(t, []string{"foo_18", "foo_81"}, names) 369} 370 371func TestCleanPath(t *testing.T) { 372 assert.Equal(t, "/", cleanPath("/")) 373 assert.Equal(t, "/", cleanPath(".")) 374 assert.Equal(t, "/", cleanPath("/.")) 375 assert.Equal(t, "/", cleanPath("/a/..")) 376 assert.Equal(t, "/a/c", cleanPath("/a/b/../c")) 377 assert.Equal(t, "/a/c", cleanPath("/a/b/../c/")) 378 assert.Equal(t, "/a", cleanPath("/a/b/..")) 379 assert.Equal(t, "/a/b/c", cleanPath("/a/b/c")) 380 assert.Equal(t, "/", cleanPath("//")) 381 assert.Equal(t, "/a", cleanPath("/a/")) 382 assert.Equal(t, "/a", cleanPath("a/")) 383 assert.Equal(t, "/a/b/c", cleanPath("/a//b//c/")) 384 385 // filepath.ToSlash does not touch \ as char on unix systems 386 // so os.PathSeparator is used for windows compatible tests 387 bslash := string(os.PathSeparator) 388 assert.Equal(t, "/", cleanPath(bslash)) 389 assert.Equal(t, "/", cleanPath(bslash+bslash)) 390 assert.Equal(t, "/a", cleanPath(bslash+"a"+bslash)) 391 assert.Equal(t, "/a", cleanPath("a"+bslash)) 392 assert.Equal(t, "/a/b/c", 393 cleanPath(bslash+"a"+bslash+bslash+"b"+bslash+bslash+"c"+bslash)) 394} 395