1package common
2
3import (
4	"encoding/hex"
5	"encoding/json"
6	"fmt"
7	"net"
8	"os"
9	"sort"
10	"sync"
11	"time"
12
13	"github.com/yl2chen/cidranger"
14
15	"github.com/drakkan/sftpgo/v2/logger"
16	"github.com/drakkan/sftpgo/v2/util"
17)
18
// HostEvent enumerates the kinds of host events a Defender can record.
type HostEvent int

// Supported host events. The in-memory defender scores each kind with the
// DefenderConfig field noted next to it.
const (
	// HostEventLoginFailed is a failed login attempt, scored with ScoreValid.
	HostEventLoginFailed HostEvent = iota
	// HostEventUserNotFound is a login attempt for a missing user, scored with ScoreInvalid.
	HostEventUserNotFound
	// HostEventNoLoginTried is a client that did not try to log in, scored with ScoreInvalid.
	HostEventNoLoginTried
	// HostEventLimitExceeded is an exceeded limit, scored with ScoreLimitExceeded.
	HostEventLimitExceeded
)
29
// DefenderEntry describes a host known to a defender. Entries built by the
// in-memory defender set either BanTime (banned host) or Score (host with
// recent violations).
type DefenderEntry struct {
	// IP is the host address
	IP string `json:"ip"`
	// Score is the host score, omitted from JSON when zero
	Score int `json:"score,omitempty"`
	// BanTime is when the ban expires, the zero value means not banned
	BanTime time.Time `json:"ban_time,omitempty"`
}

// GetID returns an unique ID for a defender entry: the lowercase hex
// encoding of the bytes of its IP string.
func (d *DefenderEntry) GetID() string {
	raw := []byte(d.IP)
	return hex.EncodeToString(raw)
}

// GetBanTime returns the ban time formatted as RFC3339 in UTC, or an empty
// string if no ban time is set.
func (d *DefenderEntry) GetBanTime() string {
	if !d.BanTime.IsZero() {
		return d.BanTime.UTC().Format(time.RFC3339)
	}
	return ""
}

// MarshalJSON returns the JSON encoding of a DefenderEntry. Unlike the
// struct tags above, it adds the "id" field and renders the ban time as an
// RFC3339 string, omitted when unset.
func (d *DefenderEntry) MarshalJSON() ([]byte, error) {
	type defenderEntryJSON struct {
		ID      string `json:"id"`
		IP      string `json:"ip"`
		Score   int    `json:"score,omitempty"`
		BanTime string `json:"ban_time,omitempty"`
	}

	payload := defenderEntryJSON{
		ID:      d.GetID(),
		IP:      d.IP,
		Score:   d.Score,
		BanTime: d.GetBanTime(),
	}
	return json.Marshal(&payload)
}
64
65// Defender defines the interface that a defender must implements
66type Defender interface {
67	GetHosts() []*DefenderEntry
68	GetHost(ip string) (*DefenderEntry, error)
69	AddEvent(ip string, event HostEvent)
70	IsBanned(ip string) bool
71	GetBanTime(ip string) *time.Time
72	GetScore(ip string) int
73	DeleteHost(ip string) bool
74	Reload() error
75}
76
// DefenderConfig defines the "defender" configuration.
// When Enabled is false the remaining settings are not validated.
type DefenderConfig struct {
	// Enabled turns the defender on
	Enabled bool `json:"enabled" mapstructure:"enabled"`
	// BanTime is how long a host stays banned, in minutes
	BanTime int `json:"ban_time" mapstructure:"ban_time"`
	// BanTimeIncrement is the percentage of BanTime added to an active ban
	// each time the banned host tries to connect again (at least one minute)
	BanTimeIncrement int `json:"ban_time_increment" mapstructure:"ban_time_increment"`
	// Threshold is the score at which a client gets banned; every per-event
	// score must be strictly lower than it
	Threshold int `json:"threshold" mapstructure:"threshold"`
	// ScoreInvalid is the score for invalid login attempts, eg. non-existent
	// user accounts or clients disconnected for inactivity without
	// authentication attempts
	ScoreInvalid int `json:"score_invalid" mapstructure:"score_invalid"`
	// ScoreValid is the score for failed login attempts against user
	// accounts that exist
	ScoreValid int `json:"score_valid" mapstructure:"score_valid"`
	// ScoreLimitExceeded is the score for limit exceeded events, generated
	// from the rate limiters or when the max connections per host are exceeded
	ScoreLimitExceeded int `json:"score_limit_exceeded" mapstructure:"score_limit_exceeded"`
	// ObservationTime is the time window, in minutes, for tracking client
	// errors: a host is banned if its score reaches Threshold within the
	// last ObservationTime minutes
	ObservationTime int `json:"observation_time" mapstructure:"observation_time"`
	// The number of banned IPs and host scores kept in memory varies between
	// EntriesSoftLimit and EntriesHardLimit: a map that grows past the hard
	// limit is trimmed down to the soft limit
	EntriesSoftLimit int `json:"entries_soft_limit" mapstructure:"entries_soft_limit"`
	EntriesHardLimit int `json:"entries_hard_limit" mapstructure:"entries_hard_limit"`
	// SafeListFile is the path to a JSON file listing ip addresses and/or
	// networks whose events are ignored, so they never accumulate a score
	SafeListFile string `json:"safelist_file" mapstructure:"safelist_file"`
	// BlockListFile is the path to a JSON file listing ip addresses and/or
	// networks to always ban
	BlockListFile string `json:"blocklist_file" mapstructure:"blocklist_file"`
}
108
109type memoryDefender struct {
110	config *DefenderConfig
111	sync.RWMutex
112	// IP addresses of the clients trying to connected are stored inside hosts,
113	// they are added to banned once the thresold is reached.
114	// A violation from a banned host will increase the ban time
115	// based on the configured BanTimeIncrement
116	hosts     map[string]hostScore // the key is the host IP
117	banned    map[string]time.Time // the key is the host IP
118	safeList  *HostList
119	blockList *HostList
120}
121
// HostListFile defines the JSON structure expected for safe/block list
// files, for example:
//
//	{"addresses": ["192.168.1.1"], "networks": ["10.0.0.0/8"]}
type HostListFile struct {
	// IPAddresses holds single IP addresses, read from the "addresses" key
	IPAddresses []string `json:"addresses"`
	// CIDRNetworks holds networks in CIDR notation, read from the "networks" key
	CIDRNetworks []string `json:"networks"`
}
127
128// HostList defines the structure used to keep the HostListFile in memory
129type HostList struct {
130	IPAddresses map[string]bool
131	Ranges      cidranger.Ranger
132}
133
134func (h *HostList) isListed(ip string) bool {
135	if _, ok := h.IPAddresses[ip]; ok {
136		return true
137	}
138
139	ok, err := h.Ranges.Contains(net.ParseIP(ip))
140	if err != nil {
141		return false
142	}
143
144	return ok
145}
146
// hostEvent is a single scored violation recorded for a host.
type hostEvent struct {
	// dateTime is when the event was recorded
	dateTime time.Time
	// score is the number of points the event adds to the host score
	score int
}

// hostScore holds the violations recorded for a host that is not banned.
type hostScore struct {
	// TotalScore is maintained by AddEvent; readers recompute the score from
	// Events because some of them may have left the observation window
	TotalScore int
	// Events are the recorded events, oldest first
	Events []hostEvent
}
156
157// validate returns an error if the configuration is invalid
158func (c *DefenderConfig) validate() error {
159	if !c.Enabled {
160		return nil
161	}
162	if c.ScoreInvalid >= c.Threshold {
163		return fmt.Errorf("score_invalid %v cannot be greater than threshold %v", c.ScoreInvalid, c.Threshold)
164	}
165	if c.ScoreValid >= c.Threshold {
166		return fmt.Errorf("score_valid %v cannot be greater than threshold %v", c.ScoreValid, c.Threshold)
167	}
168	if c.ScoreLimitExceeded >= c.Threshold {
169		return fmt.Errorf("score_limit_exceeded %v cannot be greater than threshold %v", c.ScoreLimitExceeded, c.Threshold)
170	}
171	if c.BanTime <= 0 {
172		return fmt.Errorf("invalid ban_time %v", c.BanTime)
173	}
174	if c.BanTimeIncrement <= 0 {
175		return fmt.Errorf("invalid ban_time_increment %v", c.BanTimeIncrement)
176	}
177	if c.ObservationTime <= 0 {
178		return fmt.Errorf("invalid observation_time %v", c.ObservationTime)
179	}
180	if c.EntriesSoftLimit <= 0 {
181		return fmt.Errorf("invalid entries_soft_limit %v", c.EntriesSoftLimit)
182	}
183	if c.EntriesHardLimit <= c.EntriesSoftLimit {
184		return fmt.Errorf("invalid entries_hard_limit %v must be > %v", c.EntriesHardLimit, c.EntriesSoftLimit)
185	}
186
187	return nil
188}
189
190func newInMemoryDefender(config *DefenderConfig) (Defender, error) {
191	err := config.validate()
192	if err != nil {
193		return nil, err
194	}
195	defender := &memoryDefender{
196		config: config,
197		hosts:  make(map[string]hostScore),
198		banned: make(map[string]time.Time),
199	}
200
201	if err := defender.Reload(); err != nil {
202		return nil, err
203	}
204
205	return defender, nil
206}
207
208// Reload reloads block and safe lists
209func (d *memoryDefender) Reload() error {
210	blockList, err := loadHostListFromFile(d.config.BlockListFile)
211	if err != nil {
212		return err
213	}
214
215	d.Lock()
216	d.blockList = blockList
217	d.Unlock()
218
219	safeList, err := loadHostListFromFile(d.config.SafeListFile)
220	if err != nil {
221		return err
222	}
223
224	d.Lock()
225	d.safeList = safeList
226	d.Unlock()
227
228	return nil
229}
230
231// GetHosts returns hosts that are banned or for which some violations have been detected
232func (d *memoryDefender) GetHosts() []*DefenderEntry {
233	d.RLock()
234	defer d.RUnlock()
235
236	var result []*DefenderEntry
237	for k, v := range d.banned {
238		if v.After(time.Now()) {
239			result = append(result, &DefenderEntry{
240				IP:      k,
241				BanTime: v,
242			})
243		}
244	}
245	for k, v := range d.hosts {
246		score := 0
247		for _, event := range v.Events {
248			if event.dateTime.Add(time.Duration(d.config.ObservationTime) * time.Minute).After(time.Now()) {
249				score += event.score
250			}
251		}
252		if score > 0 {
253			result = append(result, &DefenderEntry{
254				IP:    k,
255				Score: score,
256			})
257		}
258	}
259
260	return result
261}
262
263// GetHost returns a defender host by ip, if any
264func (d *memoryDefender) GetHost(ip string) (*DefenderEntry, error) {
265	d.RLock()
266	defer d.RUnlock()
267
268	if banTime, ok := d.banned[ip]; ok {
269		if banTime.After(time.Now()) {
270			return &DefenderEntry{
271				IP:      ip,
272				BanTime: banTime,
273			}, nil
274		}
275	}
276
277	if hs, ok := d.hosts[ip]; ok {
278		score := 0
279		for _, event := range hs.Events {
280			if event.dateTime.Add(time.Duration(d.config.ObservationTime) * time.Minute).After(time.Now()) {
281				score += event.score
282			}
283		}
284		if score > 0 {
285			return &DefenderEntry{
286				IP:    ip,
287				Score: score,
288			}, nil
289		}
290	}
291
292	return nil, util.NewRecordNotFoundError("host not found")
293}
294
295// IsBanned returns true if the specified IP is banned
296// and increase ban time if the IP is found.
297// This method must be called as soon as the client connects
298func (d *memoryDefender) IsBanned(ip string) bool {
299	d.RLock()
300
301	if banTime, ok := d.banned[ip]; ok {
302		if banTime.After(time.Now()) {
303			increment := d.config.BanTime * d.config.BanTimeIncrement / 100
304			if increment == 0 {
305				increment++
306			}
307
308			d.RUnlock()
309
310			// we can save an earlier ban time if there are contemporary updates
311			// but this should not make much difference. I prefer to hold a read lock
312			// until possible for performance reasons, this method is called each
313			// time a new client connects and it must be as fast as possible
314			d.Lock()
315			d.banned[ip] = banTime.Add(time.Duration(increment) * time.Minute)
316			d.Unlock()
317
318			return true
319		}
320	}
321
322	defer d.RUnlock()
323
324	if d.blockList != nil && d.blockList.isListed(ip) {
325		// permanent ban
326		return true
327	}
328
329	return false
330}
331
332// DeleteHost removes the specified IP from the defender lists
333func (d *memoryDefender) DeleteHost(ip string) bool {
334	d.Lock()
335	defer d.Unlock()
336
337	if _, ok := d.banned[ip]; ok {
338		delete(d.banned, ip)
339		return true
340	}
341
342	if _, ok := d.hosts[ip]; ok {
343		delete(d.hosts, ip)
344		return true
345	}
346
347	return false
348}
349
350// AddEvent adds an event for the given IP.
351// This method must be called for clients not yet banned
352func (d *memoryDefender) AddEvent(ip string, event HostEvent) {
353	d.Lock()
354	defer d.Unlock()
355
356	if d.safeList != nil && d.safeList.isListed(ip) {
357		return
358	}
359
360	// ignore events for already banned hosts
361	if v, ok := d.banned[ip]; ok {
362		if v.After(time.Now()) {
363			return
364		}
365		delete(d.banned, ip)
366	}
367
368	var score int
369
370	switch event {
371	case HostEventLoginFailed:
372		score = d.config.ScoreValid
373	case HostEventLimitExceeded:
374		score = d.config.ScoreLimitExceeded
375	case HostEventUserNotFound, HostEventNoLoginTried:
376		score = d.config.ScoreInvalid
377	}
378
379	ev := hostEvent{
380		dateTime: time.Now(),
381		score:    score,
382	}
383
384	if hs, ok := d.hosts[ip]; ok {
385		hs.Events = append(hs.Events, ev)
386		hs.TotalScore = 0
387
388		idx := 0
389		for _, event := range hs.Events {
390			if event.dateTime.Add(time.Duration(d.config.ObservationTime) * time.Minute).After(time.Now()) {
391				hs.Events[idx] = event
392				hs.TotalScore += event.score
393				idx++
394			}
395		}
396
397		hs.Events = hs.Events[:idx]
398		if hs.TotalScore >= d.config.Threshold {
399			d.banned[ip] = time.Now().Add(time.Duration(d.config.BanTime) * time.Minute)
400			delete(d.hosts, ip)
401			d.cleanupBanned()
402		} else {
403			d.hosts[ip] = hs
404		}
405	} else {
406		d.hosts[ip] = hostScore{
407			TotalScore: ev.score,
408			Events:     []hostEvent{ev},
409		}
410		d.cleanupHosts()
411	}
412}
413
414func (d *memoryDefender) countBanned() int {
415	d.RLock()
416	defer d.RUnlock()
417
418	return len(d.banned)
419}
420
421func (d *memoryDefender) countHosts() int {
422	d.RLock()
423	defer d.RUnlock()
424
425	return len(d.hosts)
426}
427
428// GetBanTime returns the ban time for the given IP or nil if the IP is not banned
429func (d *memoryDefender) GetBanTime(ip string) *time.Time {
430	d.RLock()
431	defer d.RUnlock()
432
433	if banTime, ok := d.banned[ip]; ok {
434		return &banTime
435	}
436
437	return nil
438}
439
440// GetScore returns the score for the given IP
441func (d *memoryDefender) GetScore(ip string) int {
442	d.RLock()
443	defer d.RUnlock()
444
445	score := 0
446
447	if hs, ok := d.hosts[ip]; ok {
448		for _, event := range hs.Events {
449			if event.dateTime.Add(time.Duration(d.config.ObservationTime) * time.Minute).After(time.Now()) {
450				score += event.score
451			}
452		}
453	}
454
455	return score
456}
457
458func (d *memoryDefender) cleanupBanned() {
459	if len(d.banned) > d.config.EntriesHardLimit {
460		kvList := make(kvList, 0, len(d.banned))
461
462		for k, v := range d.banned {
463			if v.Before(time.Now()) {
464				delete(d.banned, k)
465			}
466
467			kvList = append(kvList, kv{
468				Key:   k,
469				Value: v.UnixNano(),
470			})
471		}
472
473		// we removed expired ip addresses, if any, above, this could be enough
474		numToRemove := len(d.banned) - d.config.EntriesSoftLimit
475
476		if numToRemove <= 0 {
477			return
478		}
479
480		sort.Sort(kvList)
481
482		for idx, kv := range kvList {
483			if idx >= numToRemove {
484				break
485			}
486
487			delete(d.banned, kv.Key)
488		}
489	}
490}
491
492func (d *memoryDefender) cleanupHosts() {
493	if len(d.hosts) > d.config.EntriesHardLimit {
494		kvList := make(kvList, 0, len(d.hosts))
495
496		for k, v := range d.hosts {
497			value := int64(0)
498			if len(v.Events) > 0 {
499				value = v.Events[len(v.Events)-1].dateTime.UnixNano()
500			}
501			kvList = append(kvList, kv{
502				Key:   k,
503				Value: value,
504			})
505		}
506
507		sort.Sort(kvList)
508
509		numToRemove := len(d.hosts) - d.config.EntriesSoftLimit
510
511		for idx, kv := range kvList {
512			if idx >= numToRemove {
513				break
514			}
515
516			delete(d.hosts, kv.Key)
517		}
518	}
519}
520
521func loadHostListFromFile(name string) (*HostList, error) {
522	if name == "" {
523		return nil, nil
524	}
525	if !util.IsFileInputValid(name) {
526		return nil, fmt.Errorf("invalid host list file name %#v", name)
527	}
528
529	info, err := os.Stat(name)
530	if err != nil {
531		return nil, err
532	}
533
534	// opinionated max size, you should avoid big host lists
535	if info.Size() > 1048576*5 { // 5MB
536		return nil, fmt.Errorf("host list file %#v is too big: %v bytes", name, info.Size())
537	}
538
539	content, err := os.ReadFile(name)
540	if err != nil {
541		return nil, fmt.Errorf("unable to read input file %#v: %v", name, err)
542	}
543
544	var hostList HostListFile
545
546	err = json.Unmarshal(content, &hostList)
547	if err != nil {
548		return nil, err
549	}
550
551	if len(hostList.CIDRNetworks) > 0 || len(hostList.IPAddresses) > 0 {
552		result := &HostList{
553			IPAddresses: make(map[string]bool),
554			Ranges:      cidranger.NewPCTrieRanger(),
555		}
556		ipCount := 0
557		cdrCount := 0
558		for _, ip := range hostList.IPAddresses {
559			if net.ParseIP(ip) == nil {
560				logger.Warn(logSender, "", "unable to parse IP %#v", ip)
561				continue
562			}
563			result.IPAddresses[ip] = true
564			ipCount++
565		}
566		for _, cidrNet := range hostList.CIDRNetworks {
567			_, network, err := net.ParseCIDR(cidrNet)
568			if err != nil {
569				logger.Warn(logSender, "", "unable to parse CIDR network %#v", cidrNet)
570				continue
571			}
572			err = result.Ranges.Insert(cidranger.NewBasicRangerEntry(*network))
573			if err == nil {
574				cdrCount++
575			}
576		}
577
578		logger.Info(logSender, "", "list %#v loaded, ip addresses loaded: %v/%v networks loaded: %v/%v",
579			name, ipCount, len(hostList.IPAddresses), cdrCount, len(hostList.CIDRNetworks))
580		return result, nil
581	}
582
583	return nil, nil
584}
585
// kv pairs a map key with an int64 sort value; the defender uses it to pick
// which entries to evict when its maps grow too large.
type kv struct {
	Key   string
	Value int64
}

// kvList is a list of kv pairs implementing sort.Interface, ordered by
// ascending Value.
type kvList []kv

// Len returns the number of pairs in the list.
func (l kvList) Len() int {
	return len(l)
}

// Less orders pairs by ascending Value.
func (l kvList) Less(i, j int) bool {
	return l[i].Value < l[j].Value
}

// Swap exchanges the pairs at positions i and j.
func (l kvList) Swap(i, j int) {
	l[j], l[i] = l[i], l[j]
}
596