1// Package util provides some common utility methods 2package util 3 4import ( 5 "bytes" 6 "crypto/aes" 7 "crypto/cipher" 8 "crypto/ecdsa" 9 "crypto/ed25519" 10 "crypto/elliptic" 11 "crypto/rand" 12 "crypto/rsa" 13 "crypto/tls" 14 "crypto/x509" 15 "encoding/hex" 16 "encoding/pem" 17 "errors" 18 "fmt" 19 "html/template" 20 "io" 21 "net" 22 "net/http" 23 "net/url" 24 "os" 25 "path" 26 "path/filepath" 27 "runtime" 28 "strings" 29 "time" 30 31 "github.com/google/uuid" 32 "github.com/lithammer/shortuuid/v3" 33 "github.com/rs/xid" 34 "golang.org/x/crypto/ssh" 35 36 "github.com/drakkan/sftpgo/v2/logger" 37) 38 39const ( 40 logSender = "util" 41 osWindows = "windows" 42) 43 44var ( 45 xForwardedFor = http.CanonicalHeaderKey("X-Forwarded-For") 46 xRealIP = http.CanonicalHeaderKey("X-Real-IP") 47 cfConnectingIP = http.CanonicalHeaderKey("CF-Connecting-IP") 48 trueClientIP = http.CanonicalHeaderKey("True-Client-IP") 49) 50 51// IsStringInSlice searches a string in a slice and returns true if the string is found 52func IsStringInSlice(obj string, list []string) bool { 53 for i := 0; i < len(list); i++ { 54 if list[i] == obj { 55 return true 56 } 57 } 58 return false 59} 60 61// IsStringPrefixInSlice searches a string prefix in a slice and returns true 62// if a matching prefix is found 63func IsStringPrefixInSlice(obj string, list []string) bool { 64 for i := 0; i < len(list); i++ { 65 if strings.HasPrefix(obj, list[i]) { 66 return true 67 } 68 } 69 return false 70} 71 72// RemoveDuplicates returns a new slice removing any duplicate element from the initial one 73func RemoveDuplicates(obj []string) []string { 74 if len(obj) == 0 { 75 return obj 76 } 77 result := make([]string, 0, len(obj)) 78 seen := make(map[string]bool) 79 for _, item := range obj { 80 if _, ok := seen[item]; !ok { 81 result = append(result, item) 82 } 83 seen[item] = true 84 } 85 return result 86} 87 88// GetTimeAsMsSinceEpoch returns unix timestamp as milliseconds from a time struct 89func 
GetTimeAsMsSinceEpoch(t time.Time) int64 { 90 return t.UnixNano() / 1000000 91} 92 93// GetTimeFromMsecSinceEpoch return a time struct from a unix timestamp with millisecond precision 94func GetTimeFromMsecSinceEpoch(msec int64) time.Time { 95 return time.Unix(0, msec*1000000) 96} 97 98// GetDurationAsString returns a string representation for a time.Duration 99func GetDurationAsString(d time.Duration) string { 100 d = d.Round(time.Second) 101 h := d / time.Hour 102 d -= h * time.Hour 103 m := d / time.Minute 104 d -= m * time.Minute 105 s := d / time.Second 106 if h > 0 { 107 return fmt.Sprintf("%02d:%02d:%02d", h, m, s) 108 } 109 return fmt.Sprintf("%02d:%02d", m, s) 110} 111 112// ByteCountSI returns humanized size in SI (decimal) format 113func ByteCountSI(b int64) string { 114 return byteCount(b, 1000) 115} 116 117// ByteCountIEC returns humanized size in IEC (binary) format 118func ByteCountIEC(b int64) string { 119 return byteCount(b, 1024) 120} 121 122func byteCount(b int64, unit int64) string { 123 if b < unit { 124 return fmt.Sprintf("%d B", b) 125 } 126 div, exp := unit, 0 127 for n := b / unit; n >= unit; n /= unit { 128 div *= unit 129 exp++ 130 } 131 if unit == 1000 { 132 return fmt.Sprintf("%.1f %cB", 133 float64(b)/float64(div), "KMGTPE"[exp]) 134 } 135 return fmt.Sprintf("%.1f %ciB", 136 float64(b)/float64(div), "KMGTPE"[exp]) 137} 138 139// GetIPFromRemoteAddress returns the IP from the remote address. 
140// If the given remote address cannot be parsed it will be returned unchanged 141func GetIPFromRemoteAddress(remoteAddress string) string { 142 ip, _, err := net.SplitHostPort(remoteAddress) 143 if err == nil { 144 return ip 145 } 146 return remoteAddress 147} 148 149// NilIfEmpty returns nil if the input string is empty 150func NilIfEmpty(s string) *string { 151 if len(s) == 0 { 152 return nil 153 } 154 return &s 155} 156 157// EncryptData encrypts data using the given key 158func EncryptData(data string) (string, error) { 159 var result string 160 key := make([]byte, 16) 161 if _, err := io.ReadFull(rand.Reader, key); err != nil { 162 return result, err 163 } 164 keyHex := hex.EncodeToString(key) 165 block, err := aes.NewCipher([]byte(keyHex)) 166 if err != nil { 167 return result, err 168 } 169 gcm, err := cipher.NewGCM(block) 170 if err != nil { 171 return result, err 172 } 173 nonce := make([]byte, gcm.NonceSize()) 174 if _, err = io.ReadFull(rand.Reader, nonce); err != nil { 175 return result, err 176 } 177 ciphertext := gcm.Seal(nonce, nonce, []byte(data), nil) 178 result = fmt.Sprintf("$aes$%s$%x", keyHex, ciphertext) 179 return result, err 180} 181 182// RemoveDecryptionKey returns encrypted data without the decryption key 183func RemoveDecryptionKey(encryptData string) string { 184 vals := strings.Split(encryptData, "$") 185 if len(vals) == 4 { 186 return fmt.Sprintf("$%v$%v", vals[1], vals[3]) 187 } 188 return encryptData 189} 190 191// DecryptData decrypts data encrypted using EncryptData 192func DecryptData(data string) (string, error) { 193 var result string 194 vals := strings.Split(data, "$") 195 if len(vals) != 4 { 196 return "", errors.New("data to decrypt is not in the correct format") 197 } 198 key := vals[2] 199 encrypted, err := hex.DecodeString(vals[3]) 200 if err != nil { 201 return result, err 202 } 203 block, err := aes.NewCipher([]byte(key)) 204 if err != nil { 205 return result, err 206 } 207 gcm, err := cipher.NewGCM(block) 208 if 
err != nil { 209 return result, err 210 } 211 nonceSize := gcm.NonceSize() 212 if len(encrypted) < nonceSize { 213 return result, errors.New("malformed ciphertext") 214 } 215 nonce, ciphertext := encrypted[:nonceSize], encrypted[nonceSize:] 216 plaintext, err := gcm.Open(nil, nonce, ciphertext, nil) 217 if err != nil { 218 return result, err 219 } 220 return string(plaintext), nil 221} 222 223// GenerateRSAKeys generate rsa private and public keys and write the 224// private key to specified file and the public key to the specified 225// file adding the .pub suffix 226func GenerateRSAKeys(file string) error { 227 if err := createDirPathIfMissing(file, 0700); err != nil { 228 return err 229 } 230 key, err := rsa.GenerateKey(rand.Reader, 4096) 231 if err != nil { 232 return err 233 } 234 235 o, err := os.OpenFile(file, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) 236 if err != nil { 237 return err 238 } 239 defer o.Close() 240 241 priv := &pem.Block{ 242 Type: "RSA PRIVATE KEY", 243 Bytes: x509.MarshalPKCS1PrivateKey(key), 244 } 245 246 if err := pem.Encode(o, priv); err != nil { 247 return err 248 } 249 250 pub, err := ssh.NewPublicKey(&key.PublicKey) 251 if err != nil { 252 return err 253 } 254 return os.WriteFile(file+".pub", ssh.MarshalAuthorizedKey(pub), 0600) 255} 256 257// GenerateECDSAKeys generate ecdsa private and public keys and write the 258// private key to specified file and the public key to the specified 259// file adding the .pub suffix 260func GenerateECDSAKeys(file string) error { 261 if err := createDirPathIfMissing(file, 0700); err != nil { 262 return err 263 } 264 key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) 265 if err != nil { 266 return err 267 } 268 269 keyBytes, err := x509.MarshalECPrivateKey(key) 270 if err != nil { 271 return err 272 } 273 priv := &pem.Block{ 274 Type: "EC PRIVATE KEY", 275 Bytes: keyBytes, 276 } 277 278 o, err := os.OpenFile(file, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) 279 if err != nil { 280 return err 281 } 
282 defer o.Close() 283 284 if err := pem.Encode(o, priv); err != nil { 285 return err 286 } 287 288 pub, err := ssh.NewPublicKey(&key.PublicKey) 289 if err != nil { 290 return err 291 } 292 return os.WriteFile(file+".pub", ssh.MarshalAuthorizedKey(pub), 0600) 293} 294 295// GenerateEd25519Keys generate ed25519 private and public keys and write the 296// private key to specified file and the public key to the specified 297// file adding the .pub suffix 298func GenerateEd25519Keys(file string) error { 299 pubKey, privKey, err := ed25519.GenerateKey(rand.Reader) 300 if err != nil { 301 return err 302 } 303 keyBytes, err := x509.MarshalPKCS8PrivateKey(privKey) 304 if err != nil { 305 return err 306 } 307 priv := &pem.Block{ 308 Type: "PRIVATE KEY", 309 Bytes: keyBytes, 310 } 311 o, err := os.OpenFile(file, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) 312 if err != nil { 313 return err 314 } 315 defer o.Close() 316 317 if err := pem.Encode(o, priv); err != nil { 318 return err 319 } 320 pub, err := ssh.NewPublicKey(pubKey) 321 if err != nil { 322 return err 323 } 324 return os.WriteFile(file+".pub", ssh.MarshalAuthorizedKey(pub), 0600) 325} 326 327// GetDirsForVirtualPath returns all the directory for the given path in reverse order 328// for example if the path is: /1/2/3/4 it returns: 329// [ "/1/2/3/4", "/1/2/3", "/1/2", "/1", "/" ] 330func GetDirsForVirtualPath(virtualPath string) []string { 331 if virtualPath == "." 
{ 332 virtualPath = "/" 333 } else { 334 if !path.IsAbs(virtualPath) { 335 virtualPath = CleanPath(virtualPath) 336 } 337 } 338 dirsForPath := []string{virtualPath} 339 for { 340 if virtualPath == "/" { 341 break 342 } 343 virtualPath = path.Dir(virtualPath) 344 dirsForPath = append(dirsForPath, virtualPath) 345 } 346 return dirsForPath 347} 348 349// CleanPath returns a clean POSIX (/) absolute path to work with 350func CleanPath(p string) string { 351 p = filepath.ToSlash(p) 352 if !path.IsAbs(p) { 353 p = "/" + p 354 } 355 return path.Clean(p) 356} 357 358// LoadTemplate parses the given template paths. 359// It behaves like template.Must but it writes a log before exiting. 360// You can optionally provide a base template (e.g. to define some custom functions) 361func LoadTemplate(base *template.Template, paths ...string) *template.Template { 362 var t *template.Template 363 var err error 364 365 if base != nil { 366 base, err = base.Clone() 367 if err == nil { 368 t, err = base.ParseFiles(paths...) 369 } 370 } else { 371 t, err = template.ParseFiles(paths...) 372 } 373 374 if err != nil { 375 logger.ErrorToConsole("error loading required template: %v", err) 376 logger.Error(logSender, "", "error loading required template: %v", err) 377 panic(err) 378 } 379 return t 380} 381 382// IsFileInputValid returns true this is a valid file name. 383// This method must be used before joining a file name, generally provided as 384// user input, with a directory 385func IsFileInputValid(fileInput string) bool { 386 cleanInput := filepath.Clean(fileInput) 387 if cleanInput == "." || cleanInput == ".." { 388 return false 389 } 390 return true 391} 392 393// CleanDirInput sanitizes user input for directories. 394// On Windows it removes any trailing `"`. 395// We try to help windows users that set an invalid path such as "C:\ProgramData\SFTPGO\". 
396// This will only help if the invalid path is the last argument, for example in this command: 397// sftpgo.exe serve -c "C:\ProgramData\SFTPGO\" -l "sftpgo.log" 398// the -l flag will be ignored and the -c flag will get the value `C:\ProgramData\SFTPGO" -l sftpgo.log` 399// since the backslash after SFTPGO escape the double quote. This is definitely a bad user input 400func CleanDirInput(dirInput string) string { 401 if runtime.GOOS == osWindows { 402 for strings.HasSuffix(dirInput, "\"") { 403 dirInput = strings.TrimSuffix(dirInput, "\"") 404 } 405 } 406 return filepath.Clean(dirInput) 407} 408 409func createDirPathIfMissing(file string, perm os.FileMode) error { 410 dirPath := filepath.Dir(file) 411 if _, err := os.Stat(dirPath); os.IsNotExist(err) { 412 err = os.MkdirAll(dirPath, perm) 413 if err != nil { 414 return err 415 } 416 } 417 return nil 418} 419 420// GenerateRandomBytes generates the secret to use for JWT auth 421func GenerateRandomBytes(length int) []byte { 422 b := make([]byte, length) 423 _, err := io.ReadFull(rand.Reader, b) 424 if err == nil { 425 return b 426 } 427 428 b = xid.New().Bytes() 429 for len(b) < length { 430 b = append(b, xid.New().Bytes()...) 
431 } 432 433 return b[:length] 434} 435 436// GenerateUniqueID retuens an unique ID 437func GenerateUniqueID() string { 438 u, err := uuid.NewRandom() 439 if err != nil { 440 return xid.New().String() 441 } 442 return shortuuid.DefaultEncoder.Encode(u) 443} 444 445// HTTPListenAndServe is a wrapper for ListenAndServe that support both tcp 446// and Unix-domain sockets 447func HTTPListenAndServe(srv *http.Server, address string, port int, isTLS bool, logSender string) error { 448 var listener net.Listener 449 var err error 450 451 if filepath.IsAbs(address) && runtime.GOOS != osWindows { 452 if !IsFileInputValid(address) { 453 return fmt.Errorf("invalid socket address %#v", address) 454 } 455 err = createDirPathIfMissing(address, os.ModePerm) 456 if err != nil { 457 logger.ErrorToConsole("error creating Unix-domain socket parent dir: %v", err) 458 logger.Error(logSender, "", "error creating Unix-domain socket parent dir: %v", err) 459 } 460 os.Remove(address) 461 listener, err = newListener("unix", address, srv.ReadTimeout, srv.WriteTimeout) 462 } else { 463 CheckTCP4Port(port) 464 listener, err = newListener("tcp", fmt.Sprintf("%s:%d", address, port), srv.ReadTimeout, srv.WriteTimeout) 465 } 466 if err != nil { 467 return err 468 } 469 470 logger.Info(logSender, "", "server listener registered, address: %v TLS enabled: %v", listener.Addr().String(), isTLS) 471 472 defer listener.Close() 473 474 if isTLS { 475 return srv.ServeTLS(listener, "", "") 476 } 477 return srv.Serve(listener) 478} 479 480// GetTLSCiphersFromNames returns the TLS ciphers from the specified names 481func GetTLSCiphersFromNames(cipherNames []string) []uint16 { 482 var ciphers []uint16 483 484 for _, name := range RemoveDuplicates(cipherNames) { 485 for _, c := range tls.CipherSuites() { 486 if c.Name == strings.TrimSpace(name) { 487 ciphers = append(ciphers, c.ID) 488 } 489 } 490 } 491 492 return ciphers 493} 494 495// EncodeTLSCertToPem returns the specified certificate PEM encoded. 
496// This can be verified using openssl x509 -in cert.crt -text -noout 497func EncodeTLSCertToPem(tlsCert *x509.Certificate) (string, error) { 498 if len(tlsCert.Raw) == 0 { 499 return "", errors.New("invalid x509 certificate, no der contents") 500 } 501 publicKeyBlock := pem.Block{ 502 Type: "CERTIFICATE", 503 Bytes: tlsCert.Raw, 504 } 505 return string(pem.EncodeToMemory(&publicKeyBlock)), nil 506} 507 508// CheckTCP4Port quits the app if bind on the given IPv4 port fails. 509// This is a ugly hack to avoid to bind on an already used port. 510// It is required on Windows only. Upstream does not consider this 511// behaviour a bug: 512// https://github.com/golang/go/issues/45150 513func CheckTCP4Port(port int) { 514 if runtime.GOOS != osWindows { 515 return 516 } 517 listener, err := net.Listen("tcp4", fmt.Sprintf(":%d", port)) 518 if err != nil { 519 logger.ErrorToConsole("unable to bind on tcp4 address: %v", err) 520 logger.Error(logSender, "", "unable to bind on tcp4 address: %v", err) 521 os.Exit(1) 522 } 523 listener.Close() 524} 525 526// IsByteArrayEmpty return true if the byte array is empty or a new line 527func IsByteArrayEmpty(b []byte) bool { 528 if len(b) == 0 { 529 return true 530 } 531 if bytes.Equal(b, []byte("\n")) { 532 return true 533 } 534 if bytes.Equal(b, []byte("\r\n")) { 535 return true 536 } 537 return false 538} 539 540// GetSSHPublicKeyAsString returns an SSH public key serialized as string 541func GetSSHPublicKeyAsString(pubKey []byte) (string, error) { 542 if len(pubKey) == 0 { 543 return "", nil 544 } 545 k, err := ssh.ParsePublicKey(pubKey) 546 if err != nil { 547 return "", err 548 } 549 return string(ssh.MarshalAuthorizedKey(k)), nil 550} 551 552// GetRealIP returns the ip address as result of parsing either the 553// X-Real-IP header or the X-Forwarded-For header 554func GetRealIP(r *http.Request) string { 555 var ip string 556 557 if clientIP := r.Header.Get(trueClientIP); clientIP != "" { 558 ip = clientIP 559 } else if xrip := 
r.Header.Get(xRealIP); xrip != "" { 560 ip = xrip 561 } else if clientIP := r.Header.Get(cfConnectingIP); clientIP != "" { 562 ip = clientIP 563 } else if xff := r.Header.Get(xForwardedFor); xff != "" { 564 i := strings.Index(xff, ", ") 565 if i == -1 { 566 i = len(xff) 567 } 568 ip = strings.TrimSpace(xff[:i]) 569 } 570 if net.ParseIP(ip) == nil { 571 return "" 572 } 573 574 return ip 575} 576 577// GetHTTPLocalAddress returns the local address for an http.Request 578// or empty if it cannot be determined 579func GetHTTPLocalAddress(r *http.Request) string { 580 if r == nil { 581 return "" 582 } 583 localAddr, ok := r.Context().Value(http.LocalAddrContextKey).(net.Addr) 584 if ok { 585 return localAddr.String() 586 } 587 return "" 588} 589 590// ParseAllowedIPAndRanges returns a list of functions that allow to find if an 591// IP is equal or is contained within the allowed list 592func ParseAllowedIPAndRanges(allowed []string) ([]func(net.IP) bool, error) { 593 res := make([]func(net.IP) bool, len(allowed)) 594 for i, allowFrom := range allowed { 595 if strings.LastIndex(allowFrom, "/") > 0 { 596 _, ipRange, err := net.ParseCIDR(allowFrom) 597 if err != nil { 598 return nil, fmt.Errorf("given string %q is not a valid IP range: %v", allowFrom, err) 599 } 600 601 res[i] = ipRange.Contains 602 } else { 603 allowed := net.ParseIP(allowFrom) 604 if allowed == nil { 605 return nil, fmt.Errorf("given string %q is not a valid IP address", allowFrom) 606 } 607 608 res[i] = allowed.Equal 609 } 610 } 611 612 return res, nil 613} 614 615// GetRedactedURL returns the url redacting the password if any 616func GetRedactedURL(rawurl string) string { 617 if !strings.HasPrefix(rawurl, "http") { 618 return rawurl 619 } 620 u, err := url.Parse(rawurl) 621 if err != nil { 622 return rawurl 623 } 624 return u.Redacted() 625} 626 627// PrependFileInfo prepends a file info to a slice in an efficient way. 
// If files has spare capacity, its backing array is reused and nothing is
// allocated; otherwise append grows the slice as usual. As with append, the
// result may share storage with files.
func PrependFileInfo(files []os.FileInfo, info os.FileInfo) []os.FileInfo {
	// Grow by one slot, shift every element one position to the right and
	// store info in the freed first slot.
	out := append(files, nil)
	copy(out[1:], out[:len(out)-1])
	out[0] = info
	return out
}