1package sftp
2
3import (
4	"errors"
5	"io"
6	"syscall"
7	"testing"
8
9	"github.com/stretchr/testify/assert"
10)
11
12func TestErrFxCode(t *testing.T) {
13	table := []struct {
14		err error
15		fx  fxerr
16	}{
17		{err: errors.New("random error"), fx: ErrSSHFxFailure},
18		{err: EBADF, fx: ErrSSHFxFailure},
19		{err: syscall.ENOENT, fx: ErrSSHFxNoSuchFile},
20		{err: syscall.EPERM, fx: ErrSSHFxPermissionDenied},
21		{err: io.EOF, fx: ErrSSHFxEOF},
22	}
23	for _, tt := range table {
24		statusErr := statusFromError(1, tt.err).StatusError
25		assert.Equal(t, statusErr.FxCode(), tt.fx)
26	}
27}
28
29func TestSupportedExtensions(t *testing.T) {
30	for _, supportedExtension := range supportedSFTPExtensions {
31		_, err := getSupportedExtensionByName(supportedExtension.Name)
32		assert.NoError(t, err)
33	}
34	_, err := getSupportedExtensionByName("invalid@example.com")
35	assert.Error(t, err)
36}
37
38func TestExtensions(t *testing.T) {
39	var supportedExtensions []string
40	for _, supportedExtension := range supportedSFTPExtensions {
41		supportedExtensions = append(supportedExtensions, supportedExtension.Name)
42	}
43
44	testSFTPExtensions := []string{"hardlink@openssh.com"}
45	expectedSFTPExtensions := []sshExtensionPair{
46		{"hardlink@openssh.com", "1"},
47	}
48	err := SetSFTPExtensions(testSFTPExtensions...)
49	assert.NoError(t, err)
50	assert.Equal(t, expectedSFTPExtensions, sftpExtensions)
51
52	invalidSFTPExtensions := []string{"invalid@example.com"}
53	err = SetSFTPExtensions(invalidSFTPExtensions...)
54	assert.Error(t, err)
55	assert.Equal(t, expectedSFTPExtensions, sftpExtensions)
56
57	emptySFTPExtensions := []string{}
58	expectedSFTPExtensions = []sshExtensionPair{}
59	err = SetSFTPExtensions(emptySFTPExtensions...)
60	assert.NoError(t, err)
61	assert.Equal(t, expectedSFTPExtensions, sftpExtensions)
62
63	// if we only have an invalid extension nothing will be modified.
64	invalidSFTPExtensions = []string{
65		"hardlink@openssh.com",
66		"invalid@example.com",
67	}
68	err = SetSFTPExtensions(invalidSFTPExtensions...)
69	assert.Error(t, err)
70	assert.Equal(t, expectedSFTPExtensions, sftpExtensions)
71
72	err = SetSFTPExtensions(supportedExtensions...)
73	assert.Equal(t, supportedSFTPExtensions, sftpExtensions)
74}
75