1// Package common defines code shared among file transfer packages and protocols
2package common
3
4import (
5	"context"
6	"errors"
7	"fmt"
8	"net"
9	"net/http"
10	"net/url"
11	"os"
12	"os/exec"
13	"path/filepath"
14	"strconv"
15	"strings"
16	"sync"
17	"sync/atomic"
18	"time"
19
20	"github.com/pires/go-proxyproto"
21
22	"github.com/drakkan/sftpgo/v2/dataprovider"
23	"github.com/drakkan/sftpgo/v2/httpclient"
24	"github.com/drakkan/sftpgo/v2/logger"
25	"github.com/drakkan/sftpgo/v2/metric"
26	"github.com/drakkan/sftpgo/v2/util"
27	"github.com/drakkan/sftpgo/v2/vfs"
28)
29
// Log senders used to tag the log entries emitted by this package
const (
	logSender         = "common"
	uploadLogSender   = "Upload"
	downloadLogSender = "Download"
	renameLogSender   = "Rename"
	rmdirLogSender    = "Rmdir"
	mkdirLogSender    = "Mkdir"
	symlinkLogSender  = "Symlink"
	removeLogSender   = "Remove"
	chownLogSender    = "Chown"
	chmodLogSender    = "Chmod"
	chtimesLogSender  = "Chtimes"
	truncateLogSender = "Truncate"
)

// Operation and action names
const (
	operationDownload = "download"
	operationUpload   = "upload"
	operationDelete   = "delete"
	// OperationPreDownload is the pre-download action name
	OperationPreDownload = "pre-download"
	// OperationPreUpload is the pre-upload action name
	OperationPreUpload = "pre-upload"
	operationPreDelete = "pre-delete"
	operationRename    = "rename"
	operationMkdir     = "mkdir"
	operationRmdir     = "rmdir"
	// OperationSSHCmd is the SSH command action name
	OperationSSHCmd = "ssh_cmd"
)

const (
	// chtimesFormat is a time layout in the form YYYY-MM-DDTHH:MM:SS
	chtimesFormat = "2006-01-02T15:04:05"
	// idleTimeoutCheckInterval is the period Initialize uses for the idle timeout ticker
	idleTimeoutCheckInterval = 3 * time.Minute
)
60
// Stat flags. Each flag is a distinct bit so they can be OR-ed together.
const (
	StatAttrUIDGID = 1 << iota // 1
	StatAttrPerms              // 2
	StatAttrTimes              // 4
	StatAttrSize               // 8
)
68
// Transfer types
const (
	// TransferUpload identifies an upload
	TransferUpload = 0
	// TransferDownload identifies a download
	TransferDownload = 1
)
74
// Names of the supported protocols
const (
	ProtocolSFTP   = "SFTP"
	ProtocolSCP    = "SCP"
	ProtocolSSH    = "SSH"
	ProtocolFTP    = "FTP"
	ProtocolWebDAV = "DAV"
	ProtocolHTTP   = "HTTP"
	// ProtocolHTTPShare and ProtocolDataRetention are protocol names too,
	// what they are used for is not visible in this part of the file
	ProtocolHTTPShare     = "HTTPShare"
	ProtocolDataRetention = "DataRetention"
)
86
// Upload modes, the values match the ones documented for Configuration.UploadMode
const (
	// UploadModeStandard uploads files directly to the requested path
	UploadModeStandard = 0
	// UploadModeAtomic uploads to a temporary path and renames when the upload ends
	UploadModeAtomic = 1
	// UploadModeAtomicWithResume is as atomic, but on error the temporary file is
	// renamed to the requested path so that the upload can be resumed
	UploadModeAtomicWithResume = 2
)
93
94func init() {
95	Connections.clients = clientsMap{
96		clients: make(map[string]int),
97	}
98}
99
// Errors exposed to the other packages
var (
	ErrPermissionDenied     = errors.New("permission denied")
	ErrNotExist             = errors.New("no such file or directory")
	ErrOpUnsupported        = errors.New("operation unsupported")
	ErrGenericFailure       = errors.New("failure")
	ErrQuotaExceeded        = errors.New("denying write due to space limit")
	ErrSkipPermissionsCheck = errors.New("permission check skipped")
	ErrConnectionDenied     = errors.New("you are not allowed to connect")
	ErrNoBinding            = errors.New("no binding configured")
	ErrCrtRevoked           = errors.New("your certificate has been revoked")
	ErrNoCredentials        = errors.New("no credential provided")
	ErrInternalFailure      = errors.New("internal failure")
)

// Errors private to this package
var (
	errNoTransfer       = errors.New("requested transfer not found")
	errTransferMismatch = errors.New("transfer mismatch")
)
116
117var (
118	// Config is the configuration for the supported protocols
119	Config Configuration
120	// Connections is the list of active connections
121	Connections ActiveConnections
122	// QuotaScans is the list of active quota scans
123	QuotaScans            ActiveScans
124	idleTimeoutTicker     *time.Ticker
125	idleTimeoutTickerDone chan bool
126	supportedProtocols    = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP, ProtocolWebDAV,
127		ProtocolHTTP, ProtocolHTTPShare}
128	disconnHookProtocols = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP}
129	// the map key is the protocol, for each protocol we can have multiple rate limiters
130	rateLimiters map[string][]*rateLimiter
131)
132
133// Initialize sets the common configuration
134func Initialize(c Configuration) error {
135	Config = c
136	Config.idleLoginTimeout = 2 * time.Minute
137	Config.idleTimeoutAsDuration = time.Duration(Config.IdleTimeout) * time.Minute
138	if Config.IdleTimeout > 0 {
139		startIdleTimeoutTicker(idleTimeoutCheckInterval)
140	}
141	Config.defender = nil
142	if c.DefenderConfig.Enabled {
143		defender, err := newInMemoryDefender(&c.DefenderConfig)
144		if err != nil {
145			return fmt.Errorf("defender initialization error: %v", err)
146		}
147		logger.Info(logSender, "", "defender initialized with config %+v", c.DefenderConfig)
148		Config.defender = defender
149	}
150	rateLimiters = make(map[string][]*rateLimiter)
151	for _, rlCfg := range c.RateLimitersConfig {
152		if rlCfg.isEnabled() {
153			if err := rlCfg.validate(); err != nil {
154				return fmt.Errorf("rate limiters initialization error: %v", err)
155			}
156			allowList, err := util.ParseAllowedIPAndRanges(rlCfg.AllowList)
157			if err != nil {
158				return fmt.Errorf("unable to parse rate limiter allow list %v: %v", rlCfg.AllowList, err)
159			}
160			rateLimiter := rlCfg.getLimiter()
161			rateLimiter.allowList = allowList
162			for _, protocol := range rlCfg.Protocols {
163				rateLimiters[protocol] = append(rateLimiters[protocol], rateLimiter)
164			}
165		}
166	}
167	vfs.SetTempPath(c.TempPath)
168	dataprovider.SetTempPath(c.TempPath)
169	return nil
170}
171
172// LimitRate blocks until all the configured rate limiters
173// allow one event to happen.
174// It returns an error if the time to wait exceeds the max
175// allowed delay
176func LimitRate(protocol, ip string) (time.Duration, error) {
177	for _, limiter := range rateLimiters[protocol] {
178		if delay, err := limiter.Wait(ip); err != nil {
179			logger.Debug(logSender, "", "protocol %v ip %v: %v", protocol, ip, err)
180			return delay, err
181		}
182	}
183	return 0, nil
184}
185
186// ReloadDefender reloads the defender's block and safe lists
187func ReloadDefender() error {
188	if Config.defender == nil {
189		return nil
190	}
191
192	return Config.defender.Reload()
193}
194
195// IsBanned returns true if the specified IP address is banned
196func IsBanned(ip string) bool {
197	if Config.defender == nil {
198		return false
199	}
200
201	return Config.defender.IsBanned(ip)
202}
203
204// GetDefenderBanTime returns the ban time for the given IP
205// or nil if the IP is not banned or the defender is disabled
206func GetDefenderBanTime(ip string) *time.Time {
207	if Config.defender == nil {
208		return nil
209	}
210
211	return Config.defender.GetBanTime(ip)
212}
213
214// GetDefenderHosts returns hosts that are banned or for which some violations have been detected
215func GetDefenderHosts() []*DefenderEntry {
216	if Config.defender == nil {
217		return nil
218	}
219
220	return Config.defender.GetHosts()
221}
222
223// GetDefenderHost returns a defender host by ip, if any
224func GetDefenderHost(ip string) (*DefenderEntry, error) {
225	if Config.defender == nil {
226		return nil, errors.New("defender is disabled")
227	}
228
229	return Config.defender.GetHost(ip)
230}
231
232// DeleteDefenderHost removes the specified IP address from the defender lists
233func DeleteDefenderHost(ip string) bool {
234	if Config.defender == nil {
235		return false
236	}
237
238	return Config.defender.DeleteHost(ip)
239}
240
241// GetDefenderScore returns the score for the given IP
242func GetDefenderScore(ip string) int {
243	if Config.defender == nil {
244		return 0
245	}
246
247	return Config.defender.GetScore(ip)
248}
249
250// AddDefenderEvent adds the specified defender event for the given IP
251func AddDefenderEvent(ip string, event HostEvent) {
252	if Config.defender == nil {
253		return
254	}
255
256	Config.defender.AddEvent(ip, event)
257}
258
259// the ticker cannot be started/stopped from multiple goroutines
260func startIdleTimeoutTicker(duration time.Duration) {
261	stopIdleTimeoutTicker()
262	idleTimeoutTicker = time.NewTicker(duration)
263	idleTimeoutTickerDone = make(chan bool)
264	go func() {
265		for {
266			select {
267			case <-idleTimeoutTickerDone:
268				return
269			case <-idleTimeoutTicker.C:
270				Connections.checkIdles()
271			}
272		}
273	}()
274}
275
276func stopIdleTimeoutTicker() {
277	if idleTimeoutTicker != nil {
278		idleTimeoutTicker.Stop()
279		idleTimeoutTickerDone <- true
280		idleTimeoutTicker = nil
281	}
282}
283
// ActiveTransfer defines the interface for the current active transfers
type ActiveTransfer interface {
	// accessors
	GetID() uint64
	GetType() int
	GetSize() int64
	GetVirtualPath() string
	GetStartTime() time.Time
	GetRealFsPath(fsPath string) string
	// operations
	SignalClose()
	Truncate(fsPath string, size int64) (int64, error)
	SetTimes(fsPath string, atime time.Time, mtime time.Time) bool
}
296
297// ActiveConnection defines the interface for the current active connections
298type ActiveConnection interface {
299	GetID() string
300	GetUsername() string
301	GetLocalAddress() string
302	GetRemoteAddress() string
303	GetClientVersion() string
304	GetProtocol() string
305	GetConnectionTime() time.Time
306	GetLastActivity() time.Time
307	GetCommand() string
308	Disconnect() error
309	AddTransfer(t ActiveTransfer)
310	RemoveTransfer(t ActiveTransfer)
311	GetTransfers() []ConnectionTransfer
312	CloseFS() error
313}
314
// StatAttributes defines the attributes for set stat commands
type StatAttributes struct {
	// Mode holds the file mode
	Mode os.FileMode
	// Atime and Mtime hold the access and modification times
	Atime time.Time
	Mtime time.Time
	// UID and GID hold the owner and group
	UID int
	GID int
	// Flags presumably holds a combination of the StatAttr* flags indicating
	// which attributes are set - confirm against the callers
	Flags int
	// Size holds the file size
	Size int64
}
325
// ConnectionTransfer defines the transfer details to expose
type ConnectionTransfer struct {
	// ID is not included in the JSON representation
	ID uint64 `json:"-"`
	// OperationType is compared against the upload/download operation names
	// when the transfer is rendered as a string
	OperationType string `json:"operation_type"`
	// StartTime is the transfer start as milliseconds since epoch
	StartTime int64 `json:"start_time"`
	// Size is a byte count, it is formatted with util.ByteCountIEC for display
	Size int64 `json:"size"`
	// VirtualPath is serialized as "path"
	VirtualPath string `json:"path"`
}
334
335func (t *ConnectionTransfer) getConnectionTransferAsString() string {
336	result := ""
337	switch t.OperationType {
338	case operationUpload:
339		result += "UL "
340	case operationDownload:
341		result += "DL "
342	}
343	result += fmt.Sprintf("%#v ", t.VirtualPath)
344	if t.Size > 0 {
345		elapsed := time.Since(util.GetTimeFromMsecSinceEpoch(t.StartTime))
346		speed := float64(t.Size) / float64(util.GetTimeAsMsSinceEpoch(time.Now())-t.StartTime)
347		result += fmt.Sprintf("Size: %#v Elapsed: %#v Speed: \"%.1f KB/s\"", util.ByteCountIEC(t.Size),
348			util.GetDurationAsString(elapsed), speed)
349	}
350	return result
351}
352
353// Configuration defines configuration parameters common to all supported protocols
354type Configuration struct {
355	// Maximum idle timeout as minutes. If a client is idle for a time that exceeds this setting it will be disconnected.
356	// 0 means disabled
357	IdleTimeout int `json:"idle_timeout" mapstructure:"idle_timeout"`
358	// UploadMode 0 means standard, the files are uploaded directly to the requested path.
359	// 1 means atomic: the files are uploaded to a temporary path and renamed to the requested path
360	// when the client ends the upload. Atomic mode avoid problems such as a web server that
361	// serves partial files when the files are being uploaded.
362	// In atomic mode if there is an upload error the temporary file is deleted and so the requested
363	// upload path will not contain a partial file.
364	// 2 means atomic with resume support: as atomic but if there is an upload error the temporary
365	// file is renamed to the requested path and not deleted, this way a client can reconnect and resume
366	// the upload.
367	UploadMode int `json:"upload_mode" mapstructure:"upload_mode"`
368	// Actions to execute for SFTP file operations and SSH commands
369	Actions ProtocolActions `json:"actions" mapstructure:"actions"`
370	// SetstatMode 0 means "normal mode": requests for changing permissions and owner/group are executed.
371	// 1 means "ignore mode": requests for changing permissions and owner/group are silently ignored.
372	// 2 means "ignore mode for cloud fs": requests for changing permissions and owner/group/time are
373	// silently ignored for cloud based filesystem such as S3, GCS, Azure Blob
374	SetstatMode int `json:"setstat_mode" mapstructure:"setstat_mode"`
375	// TempPath defines the path for temporary files such as those used for atomic uploads or file pipes.
376	// If you set this option you must make sure that the defined path exists, is accessible for writing
377	// by the user running SFTPGo, and is on the same filesystem as the users home directories otherwise
378	// the renaming for atomic uploads will become a copy and therefore may take a long time.
379	// The temporary files are not namespaced. The default is generally fine. Leave empty for the default.
380	TempPath string `json:"temp_path" mapstructure:"temp_path"`
381	// Support for HAProxy PROXY protocol.
382	// If you are running SFTPGo behind a proxy server such as HAProxy, AWS ELB or NGNIX, you can enable
383	// the proxy protocol. It provides a convenient way to safely transport connection information
384	// such as a client's address across multiple layers of NAT or TCP proxies to get the real
385	// client IP address instead of the proxy IP. Both protocol versions 1 and 2 are supported.
386	// - 0 means disabled
387	// - 1 means proxy protocol enabled. Proxy header will be used and requests without proxy header will be accepted.
388	// - 2 means proxy protocol required. Proxy header will be used and requests without proxy header will be rejected.
389	// If the proxy protocol is enabled in SFTPGo then you have to enable the protocol in your proxy configuration too,
390	// for example for HAProxy add "send-proxy" or "send-proxy-v2" to each server configuration line.
391	ProxyProtocol int `json:"proxy_protocol" mapstructure:"proxy_protocol"`
392	// List of IP addresses and IP ranges allowed to send the proxy header.
393	// If proxy protocol is set to 1 and we receive a proxy header from an IP that is not in the list then the
394	// connection will be accepted and the header will be ignored.
395	// If proxy protocol is set to 2 and we receive a proxy header from an IP that is not in the list then the
396	// connection will be rejected.
397	ProxyAllowed []string `json:"proxy_allowed" mapstructure:"proxy_allowed"`
398	// Absolute path to an external program or an HTTP URL to invoke as soon as SFTPGo starts.
399	// If you define an HTTP URL it will be invoked using a `GET` request.
400	// Please note that SFTPGo services may not yet be available when this hook is run.
401	// Leave empty do disable.
402	StartupHook string `json:"startup_hook" mapstructure:"startup_hook"`
403	// Absolute path to an external program or an HTTP URL to invoke after a user connects
404	// and before he tries to login. It allows you to reject the connection based on the source
405	// ip address. Leave empty do disable.
406	PostConnectHook string `json:"post_connect_hook" mapstructure:"post_connect_hook"`
407	// Absolute path to an external program or an HTTP URL to invoke after an SSH/FTP connection ends.
408	// Leave empty do disable.
409	PostDisconnectHook string `json:"post_disconnect_hook" mapstructure:"post_disconnect_hook"`
410	// Absolute path to an external program or an HTTP URL to invoke after a data retention check completes.
411	// Leave empty do disable.
412	DataRetentionHook string `json:"data_retention_hook" mapstructure:"data_retention_hook"`
413	// Maximum number of concurrent client connections. 0 means unlimited
414	MaxTotalConnections int `json:"max_total_connections" mapstructure:"max_total_connections"`
415	// Maximum number of concurrent client connections from the same host (IP). 0 means unlimited
416	MaxPerHostConnections int `json:"max_per_host_connections" mapstructure:"max_per_host_connections"`
417	// Defender configuration
418	DefenderConfig DefenderConfig `json:"defender" mapstructure:"defender"`
419	// Rate limiter configurations
420	RateLimitersConfig    []RateLimiterConfig `json:"rate_limiters" mapstructure:"rate_limiters"`
421	idleTimeoutAsDuration time.Duration
422	idleLoginTimeout      time.Duration
423	defender              Defender
424}
425
426// IsAtomicUploadEnabled returns true if atomic upload is enabled
427func (c *Configuration) IsAtomicUploadEnabled() bool {
428	return c.UploadMode == UploadModeAtomic || c.UploadMode == UploadModeAtomicWithResume
429}
430
431// GetProxyListener returns a wrapper for the given listener that supports the
432// HAProxy Proxy Protocol
433func (c *Configuration) GetProxyListener(listener net.Listener) (*proxyproto.Listener, error) {
434	var err error
435	if c.ProxyProtocol > 0 {
436		var policyFunc func(upstream net.Addr) (proxyproto.Policy, error)
437		if c.ProxyProtocol == 1 && len(c.ProxyAllowed) > 0 {
438			policyFunc, err = proxyproto.LaxWhiteListPolicy(c.ProxyAllowed)
439			if err != nil {
440				return nil, err
441			}
442		}
443		if c.ProxyProtocol == 2 {
444			if len(c.ProxyAllowed) == 0 {
445				policyFunc = func(upstream net.Addr) (proxyproto.Policy, error) {
446					return proxyproto.REQUIRE, nil
447				}
448			} else {
449				policyFunc, err = proxyproto.StrictWhiteListPolicy(c.ProxyAllowed)
450				if err != nil {
451					return nil, err
452				}
453			}
454		}
455		return &proxyproto.Listener{
456			Listener:          listener,
457			Policy:            policyFunc,
458			ReadHeaderTimeout: 5 * time.Second,
459		}, nil
460	}
461	return nil, errors.New("proxy protocol not configured")
462}
463
464// ExecuteStartupHook runs the startup hook if defined
465func (c *Configuration) ExecuteStartupHook() error {
466	if c.StartupHook == "" {
467		return nil
468	}
469	if strings.HasPrefix(c.StartupHook, "http") {
470		var url *url.URL
471		url, err := url.Parse(c.StartupHook)
472		if err != nil {
473			logger.Warn(logSender, "", "Invalid startup hook %#v: %v", c.StartupHook, err)
474			return err
475		}
476		startTime := time.Now()
477		resp, err := httpclient.RetryableGet(url.String())
478		if err != nil {
479			logger.Warn(logSender, "", "Error executing startup hook: %v", err)
480			return err
481		}
482		defer resp.Body.Close()
483		logger.Debug(logSender, "", "Startup hook executed, elapsed: %v, response code: %v", time.Since(startTime), resp.StatusCode)
484		return nil
485	}
486	if !filepath.IsAbs(c.StartupHook) {
487		err := fmt.Errorf("invalid startup hook %#v", c.StartupHook)
488		logger.Warn(logSender, "", "Invalid startup hook %#v", c.StartupHook)
489		return err
490	}
491	startTime := time.Now()
492	ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
493	defer cancel()
494	cmd := exec.CommandContext(ctx, c.StartupHook)
495	err := cmd.Run()
496	logger.Debug(logSender, "", "Startup hook executed, elapsed: %v, error: %v", time.Since(startTime), err)
497	return nil
498}
499
500func (c *Configuration) executePostDisconnectHook(remoteAddr, protocol, username, connID string, connectionTime time.Time) {
501	ipAddr := util.GetIPFromRemoteAddress(remoteAddr)
502	connDuration := int64(time.Since(connectionTime) / time.Millisecond)
503
504	if strings.HasPrefix(c.PostDisconnectHook, "http") {
505		var url *url.URL
506		url, err := url.Parse(c.PostDisconnectHook)
507		if err != nil {
508			logger.Warn(protocol, connID, "Invalid post disconnect hook %#v: %v", c.PostDisconnectHook, err)
509			return
510		}
511		q := url.Query()
512		q.Add("ip", ipAddr)
513		q.Add("protocol", protocol)
514		q.Add("username", username)
515		q.Add("connection_duration", strconv.FormatInt(connDuration, 10))
516		url.RawQuery = q.Encode()
517		startTime := time.Now()
518		resp, err := httpclient.RetryableGet(url.String())
519		respCode := 0
520		if err == nil {
521			respCode = resp.StatusCode
522			resp.Body.Close()
523		}
524		logger.Debug(protocol, connID, "Post disconnect hook response code: %v, elapsed: %v, err: %v",
525			respCode, time.Since(startTime), err)
526		return
527	}
528	if !filepath.IsAbs(c.PostDisconnectHook) {
529		logger.Debug(protocol, connID, "invalid post disconnect hook %#v", c.PostDisconnectHook)
530		return
531	}
532	ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
533	defer cancel()
534
535	startTime := time.Now()
536	cmd := exec.CommandContext(ctx, c.PostDisconnectHook)
537	cmd.Env = append(os.Environ(),
538		fmt.Sprintf("SFTPGO_CONNECTION_IP=%v", ipAddr),
539		fmt.Sprintf("SFTPGO_CONNECTION_USERNAME=%v", username),
540		fmt.Sprintf("SFTPGO_CONNECTION_DURATION=%v", connDuration),
541		fmt.Sprintf("SFTPGO_CONNECTION_PROTOCOL=%v", protocol))
542	err := cmd.Run()
543	logger.Debug(protocol, connID, "Post disconnect hook executed, elapsed: %v error: %v", time.Since(startTime), err)
544}
545
546func (c *Configuration) checkPostDisconnectHook(remoteAddr, protocol, username, connID string, connectionTime time.Time) {
547	if c.PostDisconnectHook == "" {
548		return
549	}
550	if !util.IsStringInSlice(protocol, disconnHookProtocols) {
551		return
552	}
553	go c.executePostDisconnectHook(remoteAddr, protocol, username, connID, connectionTime)
554}
555
556// ExecutePostConnectHook executes the post connect hook if defined
557func (c *Configuration) ExecutePostConnectHook(ipAddr, protocol string) error {
558	if c.PostConnectHook == "" {
559		return nil
560	}
561	if strings.HasPrefix(c.PostConnectHook, "http") {
562		var url *url.URL
563		url, err := url.Parse(c.PostConnectHook)
564		if err != nil {
565			logger.Warn(protocol, "", "Login from ip %#v denied, invalid post connect hook %#v: %v",
566				ipAddr, c.PostConnectHook, err)
567			return err
568		}
569		q := url.Query()
570		q.Add("ip", ipAddr)
571		q.Add("protocol", protocol)
572		url.RawQuery = q.Encode()
573
574		resp, err := httpclient.RetryableGet(url.String())
575		if err != nil {
576			logger.Warn(protocol, "", "Login from ip %#v denied, error executing post connect hook: %v", ipAddr, err)
577			return err
578		}
579		defer resp.Body.Close()
580		if resp.StatusCode != http.StatusOK {
581			logger.Warn(protocol, "", "Login from ip %#v denied, post connect hook response code: %v", ipAddr, resp.StatusCode)
582			return errUnexpectedHTTResponse
583		}
584		return nil
585	}
586	if !filepath.IsAbs(c.PostConnectHook) {
587		err := fmt.Errorf("invalid post connect hook %#v", c.PostConnectHook)
588		logger.Warn(protocol, "", "Login from ip %#v denied: %v", ipAddr, err)
589		return err
590	}
591	ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
592	defer cancel()
593	cmd := exec.CommandContext(ctx, c.PostConnectHook)
594	cmd.Env = append(os.Environ(),
595		fmt.Sprintf("SFTPGO_CONNECTION_IP=%v", ipAddr),
596		fmt.Sprintf("SFTPGO_CONNECTION_PROTOCOL=%v", protocol))
597	err := cmd.Run()
598	if err != nil {
599		logger.Warn(protocol, "", "Login from ip %#v denied, connect hook error: %v", ipAddr, err)
600	}
601	return err
602}
603
// SSHConnection defines an ssh connection.
// Each SSH connection can open several channels for SFTP or SSH commands
type SSHConnection struct {
	id   string
	conn net.Conn
	// lastActivity is stored as nanoseconds since epoch and, after
	// construction, accessed only using atomic operations
	lastActivity int64
}

// NewSSHConnection returns a new SSHConnection with the last activity set to
// the current time
func NewSSHConnection(id string, conn net.Conn) *SSHConnection {
	now := time.Now().UnixNano()
	sshConn := &SSHConnection{
		id:           id,
		conn:         conn,
		lastActivity: now,
	}
	return sshConn
}

// GetID returns the ID for this SSHConnection
func (c *SSHConnection) GetID() string {
	return c.id
}

// UpdateLastActivity sets the last activity for this connection to the current time
func (c *SSHConnection) UpdateLastActivity() {
	now := time.Now().UnixNano()
	atomic.StoreInt64(&c.lastActivity, now)
}

// GetLastActivity returns the last connection activity
func (c *SSHConnection) GetLastActivity() time.Time {
	nanos := atomic.LoadInt64(&c.lastActivity)
	return time.Unix(0, nanos)
}

// Close closes the underlying network connection
func (c *SSHConnection) Close() error {
	return c.conn.Close()
}
640
641// ActiveConnections holds the currect active connections with the associated transfers
642type ActiveConnections struct {
643	// clients contains both authenticated and estabilished connections and the ones waiting
644	// for authentication
645	clients clientsMap
646	sync.RWMutex
647	connections    []ActiveConnection
648	sshConnections []*SSHConnection
649}
650
651// GetActiveSessions returns the number of active sessions for the given username.
652// We return the open sessions for any protocol
653func (conns *ActiveConnections) GetActiveSessions(username string) int {
654	conns.RLock()
655	defer conns.RUnlock()
656
657	numSessions := 0
658	for _, c := range conns.connections {
659		if c.GetUsername() == username {
660			numSessions++
661		}
662	}
663	return numSessions
664}
665
666// Add adds a new connection to the active ones
667func (conns *ActiveConnections) Add(c ActiveConnection) {
668	conns.Lock()
669	defer conns.Unlock()
670
671	conns.connections = append(conns.connections, c)
672	metric.UpdateActiveConnectionsSize(len(conns.connections))
673	logger.Debug(c.GetProtocol(), c.GetID(), "connection added, local address %#v, remote address %#v, num open connections: %v",
674		c.GetLocalAddress(), c.GetRemoteAddress(), len(conns.connections))
675}
676
677// Swap replaces an existing connection with the given one.
678// This method is useful if you have to change some connection details
679// for example for FTP is used to update the connection once the user
680// authenticates
681func (conns *ActiveConnections) Swap(c ActiveConnection) error {
682	conns.Lock()
683	defer conns.Unlock()
684
685	for idx, conn := range conns.connections {
686		if conn.GetID() == c.GetID() {
687			err := conn.CloseFS()
688			conns.connections[idx] = c
689			logger.Debug(logSender, c.GetID(), "connection swapped, close fs error: %v", err)
690			conn = nil
691			return nil
692		}
693	}
694	return errors.New("connection to swap not found")
695}
696
697// Remove removes a connection from the active ones
698func (conns *ActiveConnections) Remove(connectionID string) {
699	conns.Lock()
700	defer conns.Unlock()
701
702	for idx, conn := range conns.connections {
703		if conn.GetID() == connectionID {
704			err := conn.CloseFS()
705			lastIdx := len(conns.connections) - 1
706			conns.connections[idx] = conns.connections[lastIdx]
707			conns.connections[lastIdx] = nil
708			conns.connections = conns.connections[:lastIdx]
709			metric.UpdateActiveConnectionsSize(lastIdx)
710			logger.Debug(conn.GetProtocol(), conn.GetID(), "connection removed, local address %#v, remote address %#v close fs error: %v, num open connections: %v",
711				conn.GetLocalAddress(), conn.GetRemoteAddress(), err, lastIdx)
712			Config.checkPostDisconnectHook(conn.GetRemoteAddress(), conn.GetProtocol(), conn.GetUsername(),
713				conn.GetID(), conn.GetConnectionTime())
714			return
715		}
716	}
717	logger.Warn(logSender, "", "connection id %#v to remove not found!", connectionID)
718}
719
720// Close closes an active connection.
721// It returns true on success
722func (conns *ActiveConnections) Close(connectionID string) bool {
723	conns.RLock()
724	result := false
725
726	for _, c := range conns.connections {
727		if c.GetID() == connectionID {
728			defer func(conn ActiveConnection) {
729				err := conn.Disconnect()
730				logger.Debug(conn.GetProtocol(), conn.GetID(), "close connection requested, close err: %v", err)
731			}(c)
732			result = true
733			break
734		}
735	}
736
737	conns.RUnlock()
738	return result
739}
740
741// AddSSHConnection adds a new ssh connection to the active ones
742func (conns *ActiveConnections) AddSSHConnection(c *SSHConnection) {
743	conns.Lock()
744	defer conns.Unlock()
745
746	conns.sshConnections = append(conns.sshConnections, c)
747	logger.Debug(logSender, c.GetID(), "ssh connection added, num open connections: %v", len(conns.sshConnections))
748}
749
750// RemoveSSHConnection removes a connection from the active ones
751func (conns *ActiveConnections) RemoveSSHConnection(connectionID string) {
752	conns.Lock()
753	defer conns.Unlock()
754
755	for idx, conn := range conns.sshConnections {
756		if conn.GetID() == connectionID {
757			lastIdx := len(conns.sshConnections) - 1
758			conns.sshConnections[idx] = conns.sshConnections[lastIdx]
759			conns.sshConnections[lastIdx] = nil
760			conns.sshConnections = conns.sshConnections[:lastIdx]
761			logger.Debug(logSender, conn.GetID(), "ssh connection removed, num open ssh connections: %v", lastIdx)
762			return
763		}
764	}
765	logger.Warn(logSender, "", "ssh connection to remove with id %#v not found!", connectionID)
766}
767
768func (conns *ActiveConnections) checkIdles() {
769	conns.RLock()
770
771	for _, sshConn := range conns.sshConnections {
772		idleTime := time.Since(sshConn.GetLastActivity())
773		if idleTime > Config.idleTimeoutAsDuration {
774			// we close the an ssh connection if it has no active connections associated
775			idToMatch := fmt.Sprintf("_%v_", sshConn.GetID())
776			toClose := true
777			for _, conn := range conns.connections {
778				if strings.Contains(conn.GetID(), idToMatch) {
779					toClose = false
780					break
781				}
782			}
783			if toClose {
784				defer func(c *SSHConnection) {
785					err := c.Close()
786					logger.Debug(logSender, c.GetID(), "close idle SSH connection, idle time: %v, close err: %v",
787						time.Since(c.GetLastActivity()), err)
788				}(sshConn)
789			}
790		}
791	}
792
793	for _, c := range conns.connections {
794		idleTime := time.Since(c.GetLastActivity())
795		isUnauthenticatedFTPUser := (c.GetProtocol() == ProtocolFTP && c.GetUsername() == "")
796
797		if idleTime > Config.idleTimeoutAsDuration || (isUnauthenticatedFTPUser && idleTime > Config.idleLoginTimeout) {
798			defer func(conn ActiveConnection, isFTPNoAuth bool) {
799				err := conn.Disconnect()
800				logger.Debug(conn.GetProtocol(), conn.GetID(), "close idle connection, idle time: %v, username: %#v close err: %v",
801					time.Since(conn.GetLastActivity()), conn.GetUsername(), err)
802				if isFTPNoAuth {
803					ip := util.GetIPFromRemoteAddress(c.GetRemoteAddress())
804					logger.ConnectionFailedLog("", ip, dataprovider.LoginMethodNoAuthTryed, c.GetProtocol(), "client idle")
805					metric.AddNoAuthTryed()
806					AddDefenderEvent(ip, HostEventNoLoginTried)
807					dataprovider.ExecutePostLoginHook(&dataprovider.User{}, dataprovider.LoginMethodNoAuthTryed, ip, c.GetProtocol(),
808						dataprovider.ErrNoAuthTryed)
809				}
810			}(c, isUnauthenticatedFTPUser)
811		}
812	}
813
814	conns.RUnlock()
815}
816
817// AddClientConnection stores a new client connection
818func (conns *ActiveConnections) AddClientConnection(ipAddr string) {
819	conns.clients.add(ipAddr)
820}
821
822// RemoveClientConnection removes a disconnected client from the tracked ones
823func (conns *ActiveConnections) RemoveClientConnection(ipAddr string) {
824	conns.clients.remove(ipAddr)
825}
826
827// GetClientConnections returns the total number of client connections
828func (conns *ActiveConnections) GetClientConnections() int32 {
829	return conns.clients.getTotal()
830}
831
832// IsNewConnectionAllowed returns false if the maximum number of concurrent allowed connections is exceeded
833func (conns *ActiveConnections) IsNewConnectionAllowed(ipAddr string) bool {
834	if Config.MaxTotalConnections == 0 && Config.MaxPerHostConnections == 0 {
835		return true
836	}
837
838	if Config.MaxPerHostConnections > 0 {
839		if total := conns.clients.getTotalFrom(ipAddr); total > Config.MaxPerHostConnections {
840			logger.Debug(logSender, "", "active connections from %v %v/%v", ipAddr, total, Config.MaxPerHostConnections)
841			AddDefenderEvent(ipAddr, HostEventLimitExceeded)
842			return false
843		}
844	}
845
846	if Config.MaxTotalConnections > 0 {
847		if total := conns.clients.getTotal(); total > int32(Config.MaxTotalConnections) {
848			logger.Debug(logSender, "", "active client connections %v/%v", total, Config.MaxTotalConnections)
849			return false
850		}
851
852		// on a single SFTP connection we could have multiple SFTP channels or commands
853		// so we check the estabilished connections too
854
855		conns.RLock()
856		defer conns.RUnlock()
857
858		return len(conns.connections) < Config.MaxTotalConnections
859	}
860
861	return true
862}
863
864// GetStats returns stats for active connections
865func (conns *ActiveConnections) GetStats() []*ConnectionStatus {
866	conns.RLock()
867	defer conns.RUnlock()
868
869	stats := make([]*ConnectionStatus, 0, len(conns.connections))
870	for _, c := range conns.connections {
871		stat := &ConnectionStatus{
872			Username:       c.GetUsername(),
873			ConnectionID:   c.GetID(),
874			ClientVersion:  c.GetClientVersion(),
875			RemoteAddress:  c.GetRemoteAddress(),
876			ConnectionTime: util.GetTimeAsMsSinceEpoch(c.GetConnectionTime()),
877			LastActivity:   util.GetTimeAsMsSinceEpoch(c.GetLastActivity()),
878			Protocol:       c.GetProtocol(),
879			Command:        c.GetCommand(),
880			Transfers:      c.GetTransfers(),
881		}
882		stats = append(stats, stat)
883	}
884	return stats
885}
886
// ConnectionStatus describes the status of an active connection.
// Instances are built by ActiveConnections.GetStats; the struct tags define
// the JSON representation.
type ConnectionStatus struct {
	// Logged in username
	Username string `json:"username"`
	// Unique identifier for the connection
	ConnectionID string `json:"connection_id"`
	// Client's version string, omitted from JSON if empty
	ClientVersion string `json:"client_version,omitempty"`
	// Remote address for this connection
	RemoteAddress string `json:"remote_address"`
	// Connection time as unix timestamp in milliseconds
	ConnectionTime int64 `json:"connection_time"`
	// Last activity as unix timestamp in milliseconds
	LastActivity int64 `json:"last_activity"`
	// Protocol for this connection
	Protocol string `json:"protocol"`
	// Active uploads/downloads, omitted from JSON if there are none
	Transfers []ConnectionTransfer `json:"active_transfers,omitempty"`
	// SSH command or WebDAV method, omitted from JSON if empty
	Command string `json:"command,omitempty"`
}
908
909// GetConnectionDuration returns the connection duration as string
910func (c *ConnectionStatus) GetConnectionDuration() string {
911	elapsed := time.Since(util.GetTimeFromMsecSinceEpoch(c.ConnectionTime))
912	return util.GetDurationAsString(elapsed)
913}
914
915// GetConnectionInfo returns connection info.
916// Protocol,Client Version and RemoteAddress are returned.
917func (c *ConnectionStatus) GetConnectionInfo() string {
918	var result strings.Builder
919
920	result.WriteString(fmt.Sprintf("%v. Client: %#v From: %#v", c.Protocol, c.ClientVersion, c.RemoteAddress))
921
922	if c.Command == "" {
923		return result.String()
924	}
925
926	switch c.Protocol {
927	case ProtocolSSH, ProtocolFTP:
928		result.WriteString(fmt.Sprintf(". Command: %#v", c.Command))
929	case ProtocolWebDAV:
930		result.WriteString(fmt.Sprintf(". Method: %#v", c.Command))
931	}
932
933	return result.String()
934}
935
936// GetTransfersAsString returns the active transfers as string
937func (c *ConnectionStatus) GetTransfersAsString() string {
938	result := ""
939	for _, t := range c.Transfers {
940		if result != "" {
941			result += ". "
942		}
943		result += t.getConnectionTransferAsString()
944	}
945	return result
946}
947
// ActiveQuotaScan defines an active quota scan for a user home dir.
// Entries are created by ActiveScans.AddUserQuotaScan.
type ActiveQuotaScan struct {
	// Username to which the quota scan refers
	Username string `json:"username"`
	// Quota scan start time as unix timestamp in milliseconds
	StartTime int64 `json:"start_time"`
}
955
// ActiveVirtualFolderQuotaScan defines an active quota scan for a virtual folder.
// Entries are created by ActiveScans.AddVFolderQuotaScan.
type ActiveVirtualFolderQuotaScan struct {
	// Folder name to which the quota scan refers
	Name string `json:"name"`
	// Quota scan start time as unix timestamp in milliseconds
	StartTime int64 `json:"start_time"`
}
963
// ActiveScans holds the active quota scans for users and virtual folders.
// The methods of this type acquire the embedded RWMutex before reading or
// modifying UserScans and FolderScans.
// NOTE(review): both slices are exported, so code accessing them directly is
// presumably expected to hold the lock itself — confirm with the callers.
type ActiveScans struct {
	sync.RWMutex
	// Active quota scans for users home directories
	UserScans   []ActiveQuotaScan
	// Active quota scans for virtual folders
	FolderScans []ActiveVirtualFolderQuotaScan
}
970
971// GetUsersQuotaScans returns the active quota scans for users home directories
972func (s *ActiveScans) GetUsersQuotaScans() []ActiveQuotaScan {
973	s.RLock()
974	defer s.RUnlock()
975
976	scans := make([]ActiveQuotaScan, len(s.UserScans))
977	copy(scans, s.UserScans)
978	return scans
979}
980
981// AddUserQuotaScan adds a user to the ones with active quota scans.
982// Returns false if the user has a quota scan already running
983func (s *ActiveScans) AddUserQuotaScan(username string) bool {
984	s.Lock()
985	defer s.Unlock()
986
987	for _, scan := range s.UserScans {
988		if scan.Username == username {
989			return false
990		}
991	}
992	s.UserScans = append(s.UserScans, ActiveQuotaScan{
993		Username:  username,
994		StartTime: util.GetTimeAsMsSinceEpoch(time.Now()),
995	})
996	return true
997}
998
999// RemoveUserQuotaScan removes a user from the ones with active quota scans.
1000// Returns false if the user has no active quota scans
1001func (s *ActiveScans) RemoveUserQuotaScan(username string) bool {
1002	s.Lock()
1003	defer s.Unlock()
1004
1005	for idx, scan := range s.UserScans {
1006		if scan.Username == username {
1007			lastIdx := len(s.UserScans) - 1
1008			s.UserScans[idx] = s.UserScans[lastIdx]
1009			s.UserScans = s.UserScans[:lastIdx]
1010			return true
1011		}
1012	}
1013
1014	return false
1015}
1016
1017// GetVFoldersQuotaScans returns the active quota scans for virtual folders
1018func (s *ActiveScans) GetVFoldersQuotaScans() []ActiveVirtualFolderQuotaScan {
1019	s.RLock()
1020	defer s.RUnlock()
1021	scans := make([]ActiveVirtualFolderQuotaScan, len(s.FolderScans))
1022	copy(scans, s.FolderScans)
1023	return scans
1024}
1025
1026// AddVFolderQuotaScan adds a virtual folder to the ones with active quota scans.
1027// Returns false if the folder has a quota scan already running
1028func (s *ActiveScans) AddVFolderQuotaScan(folderName string) bool {
1029	s.Lock()
1030	defer s.Unlock()
1031
1032	for _, scan := range s.FolderScans {
1033		if scan.Name == folderName {
1034			return false
1035		}
1036	}
1037	s.FolderScans = append(s.FolderScans, ActiveVirtualFolderQuotaScan{
1038		Name:      folderName,
1039		StartTime: util.GetTimeAsMsSinceEpoch(time.Now()),
1040	})
1041	return true
1042}
1043
1044// RemoveVFolderQuotaScan removes a folder from the ones with active quota scans.
1045// Returns false if the folder has no active quota scans
1046func (s *ActiveScans) RemoveVFolderQuotaScan(folderName string) bool {
1047	s.Lock()
1048	defer s.Unlock()
1049
1050	for idx, scan := range s.FolderScans {
1051		if scan.Name == folderName {
1052			lastIdx := len(s.FolderScans) - 1
1053			s.FolderScans[idx] = s.FolderScans[lastIdx]
1054			s.FolderScans = s.FolderScans[:lastIdx]
1055			return true
1056		}
1057	}
1058
1059	return false
1060}
1061