1// Package util provides some common utility methods
2package util
3
4import (
5	"bytes"
6	"crypto/aes"
7	"crypto/cipher"
8	"crypto/ecdsa"
9	"crypto/ed25519"
10	"crypto/elliptic"
11	"crypto/rand"
12	"crypto/rsa"
13	"crypto/tls"
14	"crypto/x509"
15	"encoding/hex"
16	"encoding/pem"
17	"errors"
18	"fmt"
19	"html/template"
20	"io"
21	"net"
22	"net/http"
23	"net/url"
24	"os"
25	"path"
26	"path/filepath"
27	"runtime"
28	"strings"
29	"time"
30
31	"github.com/google/uuid"
32	"github.com/lithammer/shortuuid/v3"
33	"github.com/rs/xid"
34	"golang.org/x/crypto/ssh"
35
36	"github.com/drakkan/sftpgo/v2/logger"
37)
38
// logSender tags the log entries written by this package.
const logSender = "util"

// osWindows is the runtime.GOOS value reported on Windows.
const osWindows = "windows"
43
// Canonical forms of the proxy headers that GetRealIP inspects to find the
// client address.
var (
	trueClientIP   = http.CanonicalHeaderKey("True-Client-IP")
	xRealIP        = http.CanonicalHeaderKey("X-Real-IP")
	cfConnectingIP = http.CanonicalHeaderKey("CF-Connecting-IP")
	xForwardedFor  = http.CanonicalHeaderKey("X-Forwarded-For")
)
50
// IsStringInSlice reports whether obj is equal to at least one element of list.
// A nil or empty list never contains anything.
func IsStringInSlice(obj string, list []string) bool {
	for _, candidate := range list {
		if candidate == obj {
			return true
		}
	}
	return false
}
60
// IsStringPrefixInSlice reports whether any element of list is a prefix of
// obj. Note the direction: the list holds the prefixes, obj is the full string.
// An empty element matches every obj.
func IsStringPrefixInSlice(obj string, list []string) bool {
	for _, prefix := range list {
		if strings.HasPrefix(obj, prefix) {
			return true
		}
	}
	return false
}
71
// RemoveDuplicates returns a new slice holding the elements of obj with
// duplicates dropped, keeping the first occurrence of each value and the
// original order. An empty (or nil) input is returned as is.
func RemoveDuplicates(obj []string) []string {
	if len(obj) == 0 {
		return obj
	}
	seen := make(map[string]struct{}, len(obj))
	unique := make([]string, 0, len(obj))
	for _, value := range obj {
		if _, dup := seen[value]; dup {
			continue
		}
		seen[value] = struct{}{}
		unique = append(unique, value)
	}
	return unique
}
87
// GetTimeAsMsSinceEpoch returns t as the number of milliseconds elapsed since
// the Unix epoch (January 1, 1970 UTC).
//
// The value is computed from whole seconds plus the sub-second part instead of
// t.UnixNano(). UnixNano's result is undefined outside roughly the years
// 1678-2262, which includes the zero time.Time. Sub-millisecond precision is
// discarded, matching time.Time.UnixMilli.
func GetTimeAsMsSinceEpoch(t time.Time) int64 {
	return t.Unix()*1000 + int64(t.Nanosecond())/1000000
}
92
// GetTimeFromMsecSinceEpoch returns the local time corresponding to msec
// milliseconds since the Unix epoch. It is the inverse of GetTimeAsMsSinceEpoch.
//
// The value is split into seconds and nanoseconds rather than multiplied into
// nanoseconds. msec*1e6 would overflow int64 for instants before 1678 or after
// 2262. time.Unix normalizes the negative remainder for pre-epoch values.
func GetTimeFromMsecSinceEpoch(msec int64) time.Time {
	return time.Unix(msec/1000, (msec%1000)*1000000)
}
97
// GetDurationAsString formats d, rounded to the nearest second, as "MM:SS".
// When the rounded duration is at least one hour it uses "HH:MM:SS" instead.
// Each field is zero-padded to two digits. Negative durations are not
// special-cased, so their fields carry minus signs.
func GetDurationAsString(d time.Duration) string {
	rounded := d.Round(time.Second)
	hours := rounded / time.Hour
	minutes := (rounded % time.Hour) / time.Minute
	seconds := (rounded % time.Minute) / time.Second
	if hours <= 0 {
		return fmt.Sprintf("%02d:%02d", minutes, seconds)
	}
	return fmt.Sprintf("%02d:%02d:%02d", hours, minutes, seconds)
}
111
// ByteCountSI formats b as a human readable size using decimal (1000-based)
// multiples, e.g. "1.5 MB".
func ByteCountSI(b int64) string {
	return byteCount(b, 1000)
}

// ByteCountIEC formats b as a human readable size using binary (1024-based)
// multiples, e.g. "1.5 MiB".
func ByteCountIEC(b int64) string {
	return byteCount(b, 1024)
}

// byteCount formats b with one decimal digit. unit is the multiplier between
// successive prefixes (K, M, G, T, P, E). Values smaller than unit, including
// every negative value, are printed as plain bytes, e.g. "512 B". A unit of
// 1000 produces "KB"-style suffixes; any other unit produces "KiB"-style ones.
func byteCount(b int64, unit int64) string {
	if b < unit {
		return fmt.Sprintf("%d B", b)
	}
	exp := 0
	div := unit
	for quotient := b / unit; quotient >= unit; quotient /= unit {
		div *= unit
		exp++
	}
	value := float64(b) / float64(div)
	prefix := "KMGTPE"[exp]
	if unit != 1000 {
		return fmt.Sprintf("%.1f %ciB", value, prefix)
	}
	return fmt.Sprintf("%.1f %cB", value, prefix)
}
138
// GetIPFromRemoteAddress extracts the host part from a "host:port" remote
// address, e.g. "[::1]:22" becomes "::1". If net.SplitHostPort cannot parse
// the address it is returned unchanged.
func GetIPFromRemoteAddress(remoteAddress string) string {
	host, _, err := net.SplitHostPort(remoteAddress)
	if err != nil {
		return remoteAddress
	}
	return host
}
148
// NilIfEmpty returns a pointer to a copy of s, or nil when s is the empty string.
func NilIfEmpty(s string) *string {
	if s != "" {
		return &s
	}
	return nil
}
156
// EncryptData encrypts data with AES-GCM using a freshly generated random key.
//
// The result has the format "$aes$<key>$<hex(nonce || ciphertext)>" and can be
// reversed with DecryptData. The key is embedded in the output, so anyone who
// can read the result can decrypt it. This is obfuscation, not confidentiality.
// RemoveDecryptionKey strips the key from such a value.
//
// The key is the 32-character hex encoding of 16 random bytes. The hex string
// itself is used as a 32-byte AES-256 key, so the effective key entropy is
// 128 bits. DecryptData relies on this, so the format must not change.
func EncryptData(data string) (string, error) {
	rawKey := make([]byte, 16)
	if _, err := io.ReadFull(rand.Reader, rawKey); err != nil {
		return "", err
	}
	keyHex := hex.EncodeToString(rawKey)
	block, err := aes.NewCipher([]byte(keyHex))
	if err != nil {
		return "", err
	}
	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return "", err
	}
	nonce := make([]byte, gcm.NonceSize())
	if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
		return "", err
	}
	// Seal appends to its first argument, so the nonce is stored as a prefix
	// of the sealed payload; DecryptData splits it back off.
	ciphertext := gcm.Seal(nonce, nonce, []byte(data), nil)
	return fmt.Sprintf("$aes$%s$%x", keyHex, ciphertext), nil
}
181
// RemoveDecryptionKey strips the key from a value produced by EncryptData,
// turning "$aes$<key>$<data>" into "$aes$<data>". The stripped value has only
// three "$"-separated fields, so DecryptData will reject it. Input that does
// not split into exactly four "$"-separated fields is returned unchanged.
func RemoveDecryptionKey(encryptData string) string {
	fields := strings.Split(encryptData, "$")
	if len(fields) != 4 {
		return encryptData
	}
	return fmt.Sprintf("$%v$%v", fields[1], fields[3])
}
190
// DecryptData reverses EncryptData. It expects "$<scheme>$<key>$<hex payload>"
// and decrypts the payload with AES-GCM. The key string's bytes are the AES
// key, and the payload starts with the GCM nonce. The scheme field is not
// inspected.
//
// It returns an error if the input does not have exactly four "$"-separated
// fields, the payload is not valid hex, the key length is not a valid AES key
// size, the payload is shorter than a nonce, or authentication fails.
func DecryptData(data string) (string, error) {
	parts := strings.Split(data, "$")
	if len(parts) != 4 {
		return "", errors.New("data to decrypt is not in the correct format")
	}
	payload, err := hex.DecodeString(parts[3])
	if err != nil {
		return "", err
	}
	block, err := aes.NewCipher([]byte(parts[2]))
	if err != nil {
		return "", err
	}
	aead, err := cipher.NewGCM(block)
	if err != nil {
		return "", err
	}
	ns := aead.NonceSize()
	if len(payload) < ns {
		return "", errors.New("malformed ciphertext")
	}
	plaintext, err := aead.Open(nil, payload[:ns], payload[ns:], nil)
	if err != nil {
		return "", err
	}
	return string(plaintext), nil
}
222
223// GenerateRSAKeys generate rsa private and public keys and write the
224// private key to specified file and the public key to the specified
225// file adding the .pub suffix
226func GenerateRSAKeys(file string) error {
227	if err := createDirPathIfMissing(file, 0700); err != nil {
228		return err
229	}
230	key, err := rsa.GenerateKey(rand.Reader, 4096)
231	if err != nil {
232		return err
233	}
234
235	o, err := os.OpenFile(file, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
236	if err != nil {
237		return err
238	}
239	defer o.Close()
240
241	priv := &pem.Block{
242		Type:  "RSA PRIVATE KEY",
243		Bytes: x509.MarshalPKCS1PrivateKey(key),
244	}
245
246	if err := pem.Encode(o, priv); err != nil {
247		return err
248	}
249
250	pub, err := ssh.NewPublicKey(&key.PublicKey)
251	if err != nil {
252		return err
253	}
254	return os.WriteFile(file+".pub", ssh.MarshalAuthorizedKey(pub), 0600)
255}
256
257// GenerateECDSAKeys generate ecdsa private and public keys and write the
258// private key to specified file and the public key to the specified
259// file adding the .pub suffix
260func GenerateECDSAKeys(file string) error {
261	if err := createDirPathIfMissing(file, 0700); err != nil {
262		return err
263	}
264	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
265	if err != nil {
266		return err
267	}
268
269	keyBytes, err := x509.MarshalECPrivateKey(key)
270	if err != nil {
271		return err
272	}
273	priv := &pem.Block{
274		Type:  "EC PRIVATE KEY",
275		Bytes: keyBytes,
276	}
277
278	o, err := os.OpenFile(file, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
279	if err != nil {
280		return err
281	}
282	defer o.Close()
283
284	if err := pem.Encode(o, priv); err != nil {
285		return err
286	}
287
288	pub, err := ssh.NewPublicKey(&key.PublicKey)
289	if err != nil {
290		return err
291	}
292	return os.WriteFile(file+".pub", ssh.MarshalAuthorizedKey(pub), 0600)
293}
294
295// GenerateEd25519Keys generate ed25519 private and public keys and write the
296// private key to specified file and the public key to the specified
297// file adding the .pub suffix
298func GenerateEd25519Keys(file string) error {
299	pubKey, privKey, err := ed25519.GenerateKey(rand.Reader)
300	if err != nil {
301		return err
302	}
303	keyBytes, err := x509.MarshalPKCS8PrivateKey(privKey)
304	if err != nil {
305		return err
306	}
307	priv := &pem.Block{
308		Type:  "PRIVATE KEY",
309		Bytes: keyBytes,
310	}
311	o, err := os.OpenFile(file, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
312	if err != nil {
313		return err
314	}
315	defer o.Close()
316
317	if err := pem.Encode(o, priv); err != nil {
318		return err
319	}
320	pub, err := ssh.NewPublicKey(pubKey)
321	if err != nil {
322		return err
323	}
324	return os.WriteFile(file+".pub", ssh.MarshalAuthorizedKey(pub), 0600)
325}
326
327// GetDirsForVirtualPath returns all the directory for the given path in reverse order
328// for example if the path is: /1/2/3/4 it returns:
329// [ "/1/2/3/4", "/1/2/3", "/1/2", "/1", "/" ]
330func GetDirsForVirtualPath(virtualPath string) []string {
331	if virtualPath == "." {
332		virtualPath = "/"
333	} else {
334		if !path.IsAbs(virtualPath) {
335			virtualPath = CleanPath(virtualPath)
336		}
337	}
338	dirsForPath := []string{virtualPath}
339	for {
340		if virtualPath == "/" {
341			break
342		}
343		virtualPath = path.Dir(virtualPath)
344		dirsForPath = append(dirsForPath, virtualPath)
345	}
346	return dirsForPath
347}
348
// CleanPath converts p to a cleaned, absolute, slash-separated (POSIX) path.
// OS-specific separators are turned into "/" and a leading "/" is added when
// missing. The result is then normalized with path.Clean, so "" becomes "/".
func CleanPath(p string) string {
	slashed := filepath.ToSlash(p)
	if path.IsAbs(slashed) {
		return path.Clean(slashed)
	}
	return path.Clean("/" + slashed)
}
357
358// LoadTemplate parses the given template paths.
359// It behaves like template.Must but it writes a log before exiting.
360// You can optionally provide a base template (e.g. to define some custom functions)
361func LoadTemplate(base *template.Template, paths ...string) *template.Template {
362	var t *template.Template
363	var err error
364
365	if base != nil {
366		base, err = base.Clone()
367		if err == nil {
368			t, err = base.ParseFiles(paths...)
369		}
370	} else {
371		t, err = template.ParseFiles(paths...)
372	}
373
374	if err != nil {
375		logger.ErrorToConsole("error loading required template: %v", err)
376		logger.Error(logSender, "", "error loading required template: %v", err)
377		panic(err)
378	}
379	return t
380}
381
// IsFileInputValid reports whether fileInput, typically user provided, may be
// joined with a directory. It rejects only inputs that filepath.Clean reduces
// to "." or "..", such as "", "a/.." or "./..".
//
// NOTE(review): inputs such as "../x" or absolute paths are accepted. Callers
// that need traversal protection must not rely on this check alone.
func IsFileInputValid(fileInput string) bool {
	switch filepath.Clean(fileInput) {
	case ".", "..":
		return false
	default:
		return true
	}
}
392
393// CleanDirInput sanitizes user input for directories.
394// On Windows it removes any trailing `"`.
395// We try to help windows users that set an invalid path such as "C:\ProgramData\SFTPGO\".
396// This will only help if the invalid path is the last argument, for example in this command:
397// sftpgo.exe serve -c "C:\ProgramData\SFTPGO\" -l "sftpgo.log"
398// the -l flag will be ignored and the -c flag will get the value `C:\ProgramData\SFTPGO" -l sftpgo.log`
399// since the backslash after SFTPGO escape the double quote. This is definitely a bad user input
400func CleanDirInput(dirInput string) string {
401	if runtime.GOOS == osWindows {
402		for strings.HasSuffix(dirInput, "\"") {
403			dirInput = strings.TrimSuffix(dirInput, "\"")
404		}
405	}
406	return filepath.Clean(dirInput)
407}
408
// createDirPathIfMissing makes sure the parent directory of file exists. It
// creates that directory and any missing ancestors with permission bits perm
// (before umask). Directories that already exist are left unchanged.
//
// It returns an error if the parent cannot be created, cannot be inspected,
// or exists but is not a directory.
func createDirPathIfMissing(file string, perm os.FileMode) error {
	// MkdirAll is a no-op for an existing directory, so a separate Stat is not
	// needed. The Stat-first approach also swallowed every error other than
	// "not exist", accepted a regular file as the parent, and raced with
	// concurrent creators.
	return os.MkdirAll(filepath.Dir(file), perm)
}
419
420// GenerateRandomBytes generates the secret to use for JWT auth
421func GenerateRandomBytes(length int) []byte {
422	b := make([]byte, length)
423	_, err := io.ReadFull(rand.Reader, b)
424	if err == nil {
425		return b
426	}
427
428	b = xid.New().Bytes()
429	for len(b) < length {
430		b = append(b, xid.New().Bytes()...)
431	}
432
433	return b[:length]
434}
435
436// GenerateUniqueID retuens an unique ID
437func GenerateUniqueID() string {
438	u, err := uuid.NewRandom()
439	if err != nil {
440		return xid.New().String()
441	}
442	return shortuuid.DefaultEncoder.Encode(u)
443}
444
445// HTTPListenAndServe is a wrapper for ListenAndServe that support both tcp
446// and Unix-domain sockets
447func HTTPListenAndServe(srv *http.Server, address string, port int, isTLS bool, logSender string) error {
448	var listener net.Listener
449	var err error
450
451	if filepath.IsAbs(address) && runtime.GOOS != osWindows {
452		if !IsFileInputValid(address) {
453			return fmt.Errorf("invalid socket address %#v", address)
454		}
455		err = createDirPathIfMissing(address, os.ModePerm)
456		if err != nil {
457			logger.ErrorToConsole("error creating Unix-domain socket parent dir: %v", err)
458			logger.Error(logSender, "", "error creating Unix-domain socket parent dir: %v", err)
459		}
460		os.Remove(address)
461		listener, err = newListener("unix", address, srv.ReadTimeout, srv.WriteTimeout)
462	} else {
463		CheckTCP4Port(port)
464		listener, err = newListener("tcp", fmt.Sprintf("%s:%d", address, port), srv.ReadTimeout, srv.WriteTimeout)
465	}
466	if err != nil {
467		return err
468	}
469
470	logger.Info(logSender, "", "server listener registered, address: %v TLS enabled: %v", listener.Addr().String(), isTLS)
471
472	defer listener.Close()
473
474	if isTLS {
475		return srv.ServeTLS(listener, "", "")
476	}
477	return srv.Serve(listener)
478}
479
// GetTLSCiphersFromNames maps cipher suite names, such as
// "TLS_AES_128_GCM_SHA256", to their IDs.
//
// Names are trimmed of surrounding whitespace and looked up among
// tls.CipherSuites(), the suites Go considers secure. Unknown names and names
// found only in tls.InsecureCipherSuites() are silently ignored. The result
// follows input order and contains each ID once, even when the input repeats
// a name with different surrounding whitespace. It is nil when no name matches.
func GetTLSCiphersFromNames(cipherNames []string) []uint16 {
	if len(cipherNames) == 0 {
		return nil
	}
	// Build the lookup table once instead of calling tls.CipherSuites(),
	// which allocates a fresh slice, for every name.
	idByName := make(map[string]uint16)
	for _, suite := range tls.CipherSuites() {
		idByName[suite.Name] = suite.ID
	}

	var ciphers []uint16
	seen := make(map[uint16]struct{}, len(cipherNames))
	for _, name := range cipherNames {
		id, ok := idByName[strings.TrimSpace(name)]
		if !ok {
			continue
		}
		// Deduplicate after trimming, so " X" and "X" count as the same suite.
		if _, dup := seen[id]; dup {
			continue
		}
		seen[id] = struct{}{}
		ciphers = append(ciphers, id)
	}
	return ciphers
}
494
// EncodeTLSCertToPem returns the DER contents of tlsCert (its Raw field) as a
// PEM "CERTIFICATE" block. The output can be inspected with
// openssl x509 -in cert.crt -text -noout
//
// It returns an error, instead of panicking, if tlsCert is nil or has no DER
// contents.
func EncodeTLSCertToPem(tlsCert *x509.Certificate) (string, error) {
	if tlsCert == nil || len(tlsCert.Raw) == 0 {
		return "", errors.New("invalid x509 certificate, no der contents")
	}
	block := pem.Block{
		Type:  "CERTIFICATE",
		Bytes: tlsCert.Raw,
	}
	return string(pem.EncodeToMemory(&block)), nil
}
507
508// CheckTCP4Port quits the app if bind on the given IPv4 port fails.
509// This is a ugly hack to avoid to bind on an already used port.
510// It is required on Windows only. Upstream does not consider this
511// behaviour a bug:
512// https://github.com/golang/go/issues/45150
513func CheckTCP4Port(port int) {
514	if runtime.GOOS != osWindows {
515		return
516	}
517	listener, err := net.Listen("tcp4", fmt.Sprintf(":%d", port))
518	if err != nil {
519		logger.ErrorToConsole("unable to bind on tcp4 address: %v", err)
520		logger.Error(logSender, "", "unable to bind on tcp4 address: %v", err)
521		os.Exit(1)
522	}
523	listener.Close()
524}
525
// IsByteArrayEmpty reports whether b is empty (or nil), or consists of a
// single line terminator: "\n" or "\r\n".
func IsByteArrayEmpty(b []byte) bool {
	return len(b) == 0 ||
		bytes.Equal(b, []byte("\n")) ||
		bytes.Equal(b, []byte("\r\n"))
}
539
540// GetSSHPublicKeyAsString returns an SSH public key serialized as string
541func GetSSHPublicKeyAsString(pubKey []byte) (string, error) {
542	if len(pubKey) == 0 {
543		return "", nil
544	}
545	k, err := ssh.ParsePublicKey(pubKey)
546	if err != nil {
547		return "", err
548	}
549	return string(ssh.MarshalAuthorizedKey(k)), nil
550}
551
552// GetRealIP returns the ip address as result of parsing either the
553// X-Real-IP header or the X-Forwarded-For header
554func GetRealIP(r *http.Request) string {
555	var ip string
556
557	if clientIP := r.Header.Get(trueClientIP); clientIP != "" {
558		ip = clientIP
559	} else if xrip := r.Header.Get(xRealIP); xrip != "" {
560		ip = xrip
561	} else if clientIP := r.Header.Get(cfConnectingIP); clientIP != "" {
562		ip = clientIP
563	} else if xff := r.Header.Get(xForwardedFor); xff != "" {
564		i := strings.Index(xff, ", ")
565		if i == -1 {
566			i = len(xff)
567		}
568		ip = strings.TrimSpace(xff[:i])
569	}
570	if net.ParseIP(ip) == nil {
571		return ""
572	}
573
574	return ip
575}
576
// GetHTTPLocalAddress returns the local address on which r was received, as
// stored by net/http under http.LocalAddrContextKey in the request context.
// It returns "" if r is nil or the context does not carry a net.Addr.
func GetHTTPLocalAddress(r *http.Request) string {
	if r != nil {
		if addr, ok := r.Context().Value(http.LocalAddrContextKey).(net.Addr); ok {
			return addr.String()
		}
	}
	return ""
}
589
// ParseAllowedIPAndRanges converts each entry of allowed into a matcher,
// returned in the same order. Each matcher reports whether a given IP is
// allowed by that entry.
//
// An entry with a "/" after its first character is parsed as a CIDR range and
// matches the IPs the range contains. Any other entry must be a single IP
// address and matches only that address. The first invalid entry aborts
// parsing with an error.
func ParseAllowedIPAndRanges(allowed []string) ([]func(net.IP) bool, error) {
	matchers := make([]func(net.IP) bool, len(allowed))
	for idx, entry := range allowed {
		if strings.LastIndex(entry, "/") <= 0 {
			ip := net.ParseIP(entry)
			if ip == nil {
				return nil, fmt.Errorf("given string %q is not a valid IP address", entry)
			}
			matchers[idx] = ip.Equal
			continue
		}

		_, network, err := net.ParseCIDR(entry)
		if err != nil {
			return nil, fmt.Errorf("given string %q is not a valid IP range: %v", entry, err)
		}
		matchers[idx] = network.Contains
	}
	return matchers, nil
}
614
// GetRedactedURL returns rawurl with any password replaced, as done by
// url.URL.Redacted. Only strings starting with "http" are processed. Other
// strings, and those that url.Parse rejects, are returned unchanged.
func GetRedactedURL(rawurl string) string {
	if strings.HasPrefix(rawurl, "http") {
		if u, err := url.Parse(rawurl); err == nil {
			return u.Redacted()
		}
	}
	return rawurl
}
626
// PrependFileInfo inserts info at the front of files and returns the updated
// slice. It grows the slice by one and shifts the existing elements in place.
// Callers should preallocate spare capacity so the append does not have to
// reallocate. As with append, the caller must use the returned slice.
func PrependFileInfo(files []os.FileInfo, info os.FileInfo) []os.FileInfo {
	grown := append(files, nil)
	copy(grown[1:], grown[:len(grown)-1])
	grown[0] = info
	return grown
}
635