1package common 2 3import ( 4 "encoding/hex" 5 "encoding/json" 6 "fmt" 7 "net" 8 "os" 9 "sort" 10 "sync" 11 "time" 12 13 "github.com/yl2chen/cidranger" 14 15 "github.com/drakkan/sftpgo/v2/logger" 16 "github.com/drakkan/sftpgo/v2/util" 17) 18 19// HostEvent is the enumerable for the supported host events 20type HostEvent int 21 22// Supported host events 23const ( 24 HostEventLoginFailed HostEvent = iota 25 HostEventUserNotFound 26 HostEventNoLoginTried 27 HostEventLimitExceeded 28) 29 30// DefenderEntry defines a defender entry 31type DefenderEntry struct { 32 IP string `json:"ip"` 33 Score int `json:"score,omitempty"` 34 BanTime time.Time `json:"ban_time,omitempty"` 35} 36 37// GetID returns an unique ID for a defender entry 38func (d *DefenderEntry) GetID() string { 39 return hex.EncodeToString([]byte(d.IP)) 40} 41 42// GetBanTime returns the ban time for a defender entry as string 43func (d *DefenderEntry) GetBanTime() string { 44 if d.BanTime.IsZero() { 45 return "" 46 } 47 return d.BanTime.UTC().Format(time.RFC3339) 48} 49 50// MarshalJSON returns the JSON encoding of a DefenderEntry. 51func (d *DefenderEntry) MarshalJSON() ([]byte, error) { 52 return json.Marshal(&struct { 53 ID string `json:"id"` 54 IP string `json:"ip"` 55 Score int `json:"score,omitempty"` 56 BanTime string `json:"ban_time,omitempty"` 57 }{ 58 ID: d.GetID(), 59 IP: d.IP, 60 Score: d.Score, 61 BanTime: d.GetBanTime(), 62 }) 63} 64 65// Defender defines the interface that a defender must implements 66type Defender interface { 67 GetHosts() []*DefenderEntry 68 GetHost(ip string) (*DefenderEntry, error) 69 AddEvent(ip string, event HostEvent) 70 IsBanned(ip string) bool 71 GetBanTime(ip string) *time.Time 72 GetScore(ip string) int 73 DeleteHost(ip string) bool 74 Reload() error 75} 76 77// DefenderConfig defines the "defender" configuration 78type DefenderConfig struct { 79 // Set to true to enable the defender 80 Enabled bool `json:"enabled" mapstructure:"enabled"` 81 // BanTime is the number of minutes that a host is banned 82 BanTime int `json:"ban_time" mapstructure:"ban_time"` 83 // Percentage increase of the ban time if a banned host tries to connect again 84 BanTimeIncrement int `json:"ban_time_increment" mapstructure:"ban_time_increment"` 85 // Threshold value for banning a client 86 Threshold int `json:"threshold" mapstructure:"threshold"` 87 // Score for invalid login attempts, eg. non-existent user accounts or 88 // client disconnected for inactivity without authentication attempts 89 ScoreInvalid int `json:"score_invalid" mapstructure:"score_invalid"` 90 // Score for valid login attempts, eg. user accounts that exist 91 ScoreValid int `json:"score_valid" mapstructure:"score_valid"` 92 // Score for limit exceeded events, generated from the rate limiters or for max connections 93 // per-host exceeded 94 ScoreLimitExceeded int `json:"score_limit_exceeded" mapstructure:"score_limit_exceeded"` 95 // Defines the time window, in minutes, for tracking client errors. 96 // A host is banned if it has exceeded the defined threshold during 97 // the last observation time minutes 98 ObservationTime int `json:"observation_time" mapstructure:"observation_time"` 99 // The number of banned IPs and host scores kept in memory will vary between the 100 // soft and hard limit 101 EntriesSoftLimit int `json:"entries_soft_limit" mapstructure:"entries_soft_limit"` 102 EntriesHardLimit int `json:"entries_hard_limit" mapstructure:"entries_hard_limit"` 103 // Path to a file containing a list of ip addresses and/or networks to never ban 104 SafeListFile string `json:"safelist_file" mapstructure:"safelist_file"` 105 // Path to a file containing a list of ip addresses and/or networks to always ban 106 BlockListFile string `json:"blocklist_file" mapstructure:"blocklist_file"` 107} 108 109type memoryDefender struct { 110 config *DefenderConfig 111 sync.RWMutex 112 // IP addresses of the clients trying to connected are stored inside hosts, 113 // they are added to banned once the thresold is reached. 114 // A violation from a banned host will increase the ban time 115 // based on the configured BanTimeIncrement 116 hosts map[string]hostScore // the key is the host IP 117 banned map[string]time.Time // the key is the host IP 118 safeList *HostList 119 blockList *HostList 120} 121 122// HostListFile defines the structure expected for safe/block list files 123type HostListFile struct { 124 IPAddresses []string `json:"addresses"` 125 CIDRNetworks []string `json:"networks"` 126} 127 128// HostList defines the structure used to keep the HostListFile in memory 129type HostList struct { 130 IPAddresses map[string]bool 131 Ranges cidranger.Ranger 132} 133 134func (h *HostList) isListed(ip string) bool { 135 if _, ok := h.IPAddresses[ip]; ok { 136 return true 137 } 138 139 ok, err := h.Ranges.Contains(net.ParseIP(ip)) 140 if err != nil { 141 return false 142 } 143 144 return ok 145} 146 147type hostEvent struct { 148 dateTime time.Time 149 score int 150} 151 152type hostScore struct { 153 TotalScore int 154 Events []hostEvent 155} 156 157// validate returns an error if the configuration is invalid 158func (c *DefenderConfig) validate() error { 159 if !c.Enabled { 160 return nil 161 } 162 if c.ScoreInvalid >= c.Threshold { 163 return fmt.Errorf("score_invalid %v cannot be greater than threshold %v", c.ScoreInvalid, c.Threshold) 164 } 165 if c.ScoreValid >= c.Threshold { 166 return fmt.Errorf("score_valid %v cannot be greater than threshold %v", c.ScoreValid, c.Threshold) 167 } 168 if c.ScoreLimitExceeded >= c.Threshold { 169 return fmt.Errorf("score_limit_exceeded %v cannot be greater than threshold %v", c.ScoreLimitExceeded, c.Threshold) 170 } 171 if c.BanTime <= 0 { 172 return fmt.Errorf("invalid ban_time %v", c.BanTime) 173 } 174 if c.BanTimeIncrement <= 0 { 175 return fmt.Errorf("invalid ban_time_increment %v", c.BanTimeIncrement) 176 } 177 if c.ObservationTime <= 0 { 178 return fmt.Errorf("invalid observation_time %v", c.ObservationTime) 179 } 180 if c.EntriesSoftLimit <= 0 { 181 return fmt.Errorf("invalid entries_soft_limit %v", c.EntriesSoftLimit) 182 } 183 if c.EntriesHardLimit <= c.EntriesSoftLimit { 184 return fmt.Errorf("invalid entries_hard_limit %v must be > %v", c.EntriesHardLimit, c.EntriesSoftLimit) 185 } 186 187 return nil 188} 189 190func newInMemoryDefender(config *DefenderConfig) (Defender, error) { 191 err := config.validate() 192 if err != nil { 193 return nil, err 194 } 195 defender := &memoryDefender{ 196 config: config, 197 hosts: make(map[string]hostScore), 198 banned: make(map[string]time.Time), 199 } 200 201 if err := defender.Reload(); err != nil { 202 return nil, err 203 } 204 205 return defender, nil 206} 207 208// Reload reloads block and safe lists 209func (d *memoryDefender) Reload() error { 210 blockList, err := loadHostListFromFile(d.config.BlockListFile) 211 if err != nil { 212 return err 213 } 214 215 d.Lock() 216 d.blockList = blockList 217 d.Unlock() 218 219 safeList, err := loadHostListFromFile(d.config.SafeListFile) 220 if err != nil { 221 return err 222 } 223 224 d.Lock() 225 d.safeList = safeList 226 d.Unlock() 227 228 return nil 229} 230 231// GetHosts returns hosts that are banned or for which some violations have been detected 232func (d *memoryDefender) GetHosts() []*DefenderEntry { 233 d.RLock() 234 defer d.RUnlock() 235 236 var result []*DefenderEntry 237 for k, v := range d.banned { 238 if v.After(time.Now()) { 239 result = append(result, &DefenderEntry{ 240 IP: k, 241 BanTime: v, 242 }) 243 } 244 } 245 for k, v := range d.hosts { 246 score := 0 247 for _, event := range v.Events { 248 if event.dateTime.Add(time.Duration(d.config.ObservationTime) * time.Minute).After(time.Now()) { 249 score += event.score 250 } 251 } 252 if score > 0 { 253 result = append(result, &DefenderEntry{ 254 IP: k, 255 Score: score, 256 }) 257 } 258 } 259 260 return result 261} 262 263// GetHost returns a defender host by ip, if any 264func (d *memoryDefender) GetHost(ip string) (*DefenderEntry, error) { 265 d.RLock() 266 defer d.RUnlock() 267 268 if banTime, ok := d.banned[ip]; ok { 269 if banTime.After(time.Now()) { 270 return &DefenderEntry{ 271 IP: ip, 272 BanTime: banTime, 273 }, nil 274 } 275 } 276 277 if hs, ok := d.hosts[ip]; ok { 278 score := 0 279 for _, event := range hs.Events { 280 if event.dateTime.Add(time.Duration(d.config.ObservationTime) * time.Minute).After(time.Now()) { 281 score += event.score 282 } 283 } 284 if score > 0 { 285 return &DefenderEntry{ 286 IP: ip, 287 Score: score, 288 }, nil 289 } 290 } 291 292 return nil, util.NewRecordNotFoundError("host not found") 293} 294 295// IsBanned returns true if the specified IP is banned 296// and increase ban time if the IP is found. 297// This method must be called as soon as the client connects 298func (d *memoryDefender) IsBanned(ip string) bool { 299 d.RLock() 300 301 if banTime, ok := d.banned[ip]; ok { 302 if banTime.After(time.Now()) { 303 increment := d.config.BanTime * d.config.BanTimeIncrement / 100 304 if increment == 0 { 305 increment++ 306 } 307 308 d.RUnlock() 309 310 // we can save an earlier ban time if there are contemporary updates 311 // but this should not make much difference. I prefer to hold a read lock 312 // until possible for performance reasons, this method is called each 313 // time a new client connects and it must be as fast as possible 314 d.Lock() 315 d.banned[ip] = banTime.Add(time.Duration(increment) * time.Minute) 316 d.Unlock() 317 318 return true 319 } 320 } 321 322 defer d.RUnlock() 323 324 if d.blockList != nil && d.blockList.isListed(ip) { 325 // permanent ban 326 return true 327 } 328 329 return false 330} 331 332// DeleteHost removes the specified IP from the defender lists 333func (d *memoryDefender) DeleteHost(ip string) bool { 334 d.Lock() 335 defer d.Unlock() 336 337 if _, ok := d.banned[ip]; ok { 338 delete(d.banned, ip) 339 return true 340 } 341 342 if _, ok := d.hosts[ip]; ok { 343 delete(d.hosts, ip) 344 return true 345 } 346 347 return false 348} 349 350// AddEvent adds an event for the given IP. 351// This method must be called for clients not yet banned 352func (d *memoryDefender) AddEvent(ip string, event HostEvent) { 353 d.Lock() 354 defer d.Unlock() 355 356 if d.safeList != nil && d.safeList.isListed(ip) { 357 return 358 } 359 360 // ignore events for already banned hosts 361 if v, ok := d.banned[ip]; ok { 362 if v.After(time.Now()) { 363 return 364 } 365 delete(d.banned, ip) 366 } 367 368 var score int 369 370 switch event { 371 case HostEventLoginFailed: 372 score = d.config.ScoreValid 373 case HostEventLimitExceeded: 374 score = d.config.ScoreLimitExceeded 375 case HostEventUserNotFound, HostEventNoLoginTried: 376 score = d.config.ScoreInvalid 377 } 378 379 ev := hostEvent{ 380 dateTime: time.Now(), 381 score: score, 382 } 383 384 if hs, ok := d.hosts[ip]; ok { 385 hs.Events = append(hs.Events, ev) 386 hs.TotalScore = 0 387 388 idx := 0 389 for _, event := range hs.Events { 390 if event.dateTime.Add(time.Duration(d.config.ObservationTime) * time.Minute).After(time.Now()) { 391 hs.Events[idx] = event 392 hs.TotalScore += event.score 393 idx++ 394 } 395 } 396 397 hs.Events = hs.Events[:idx] 398 if hs.TotalScore >= d.config.Threshold { 399 d.banned[ip] = time.Now().Add(time.Duration(d.config.BanTime) * time.Minute) 400 delete(d.hosts, ip) 401 d.cleanupBanned() 402 } else { 403 d.hosts[ip] = hs 404 } 405 } else { 406 d.hosts[ip] = hostScore{ 407 TotalScore: ev.score, 408 Events: []hostEvent{ev}, 409 } 410 d.cleanupHosts() 411 } 412} 413 414func (d *memoryDefender) countBanned() int { 415 d.RLock() 416 defer d.RUnlock() 417 418 return len(d.banned) 419} 420 421func (d *memoryDefender) countHosts() int { 422 d.RLock() 423 defer d.RUnlock() 424 425 return len(d.hosts) 426} 427 428// GetBanTime returns the ban time for the given IP or nil if the IP is not banned 429func (d *memoryDefender) GetBanTime(ip string) *time.Time { 430 d.RLock() 431 defer d.RUnlock() 432 433 if banTime, ok := d.banned[ip]; ok { 434 return &banTime 435 } 436 437 return nil 438} 439 440// GetScore returns the score for the given IP 441func (d *memoryDefender) GetScore(ip string) int { 442 d.RLock() 443 defer d.RUnlock() 444 445 score := 0 446 447 if hs, ok := d.hosts[ip]; ok { 448 for _, event := range hs.Events { 449 if event.dateTime.Add(time.Duration(d.config.ObservationTime) * time.Minute).After(time.Now()) { 450 score += event.score 451 } 452 } 453 } 454 455 return score 456} 457 458func (d *memoryDefender) cleanupBanned() { 459 if len(d.banned) > d.config.EntriesHardLimit { 460 kvList := make(kvList, 0, len(d.banned)) 461 462 for k, v := range d.banned { 463 if v.Before(time.Now()) { 464 delete(d.banned, k) 465 } 466 467 kvList = append(kvList, kv{ 468 Key: k, 469 Value: v.UnixNano(), 470 }) 471 } 472 473 // we removed expired ip addresses, if any, above, this could be enough 474 numToRemove := len(d.banned) - d.config.EntriesSoftLimit 475 476 if numToRemove <= 0 { 477 return 478 } 479 480 sort.Sort(kvList) 481 482 for idx, kv := range kvList { 483 if idx >= numToRemove { 484 break 485 } 486 487 delete(d.banned, kv.Key) 488 } 489 } 490} 491 492func (d *memoryDefender) cleanupHosts() { 493 if len(d.hosts) > d.config.EntriesHardLimit { 494 kvList := make(kvList, 0, len(d.hosts)) 495 496 for k, v := range d.hosts { 497 value := int64(0) 498 if len(v.Events) > 0 { 499 value = v.Events[len(v.Events)-1].dateTime.UnixNano() 500 } 501 kvList = append(kvList, kv{ 502 Key: k, 503 Value: value, 504 }) 505 } 506 507 sort.Sort(kvList) 508 509 numToRemove := len(d.hosts) - d.config.EntriesSoftLimit 510 511 for idx, kv := range kvList { 512 if idx >= numToRemove { 513 break 514 } 515 516 delete(d.hosts, kv.Key) 517 } 518 } 519} 520 521func loadHostListFromFile(name string) (*HostList, error) { 522 if name == "" { 523 return nil, nil 524 } 525 if !util.IsFileInputValid(name) { 526 return nil, fmt.Errorf("invalid host list file name %#v", name) 527 } 528 529 info, err := os.Stat(name) 530 if err != nil { 531 return nil, err 532 } 533 534 // opinionated max size, you should avoid big host lists 535 if info.Size() > 1048576*5 { // 5MB 536 return nil, fmt.Errorf("host list file %#v is too big: %v bytes", name, info.Size()) 537 } 538 539 content, err := os.ReadFile(name) 540 if err != nil { 541 return nil, fmt.Errorf("unable to read input file %#v: %v", name, err) 542 } 543 544 var hostList HostListFile 545 546 err = json.Unmarshal(content, &hostList) 547 if err != nil { 548 return nil, err 549 } 550 551 if len(hostList.CIDRNetworks) > 0 || len(hostList.IPAddresses) > 0 { 552 result := &HostList{ 553 IPAddresses: make(map[string]bool), 554 Ranges: cidranger.NewPCTrieRanger(), 555 } 556 ipCount := 0 557 cdrCount := 0 558 for _, ip := range hostList.IPAddresses { 559 if net.ParseIP(ip) == nil { 560 logger.Warn(logSender, "", "unable to parse IP %#v", ip) 561 continue 562 } 563 result.IPAddresses[ip] = true 564 ipCount++ 565 } 566 for _, cidrNet := range hostList.CIDRNetworks { 567 _, network, err := net.ParseCIDR(cidrNet) 568 if err != nil { 569 logger.Warn(logSender, "", "unable to parse CIDR network %#v", cidrNet) 570 continue 571 } 572 err = result.Ranges.Insert(cidranger.NewBasicRangerEntry(*network)) 573 if err == nil { 574 cdrCount++ 575 } 576 } 577 578 logger.Info(logSender, "", "list %#v loaded, ip addresses loaded: %v/%v networks loaded: %v/%v", 579 name, ipCount, len(hostList.IPAddresses), cdrCount, len(hostList.CIDRNetworks)) 580 return result, nil 581 } 582 583 return nil, nil 584} 585 586type kv struct { 587 Key string 588 Value int64 589} 590 591type kvList []kv 592 593func (p kvList) Len() int { return len(p) } 594func (p kvList) Less(i, j int) bool { return p[i].Value < p[j].Value } 595func (p kvList) Swap(i, j int) { p[i], p[j] = p[j], p[i] } 596