1// Package httpdtest provides utilities for testing the exposed REST API. 2package httpdtest 3 4import ( 5 "bytes" 6 "encoding/hex" 7 "encoding/json" 8 "errors" 9 "fmt" 10 "io" 11 "net/http" 12 "net/url" 13 "path" 14 "strconv" 15 "strings" 16 17 "github.com/go-chi/render" 18 19 "github.com/drakkan/sftpgo/v2/common" 20 "github.com/drakkan/sftpgo/v2/dataprovider" 21 "github.com/drakkan/sftpgo/v2/httpclient" 22 "github.com/drakkan/sftpgo/v2/httpd" 23 "github.com/drakkan/sftpgo/v2/kms" 24 "github.com/drakkan/sftpgo/v2/util" 25 "github.com/drakkan/sftpgo/v2/version" 26 "github.com/drakkan/sftpgo/v2/vfs" 27) 28 29const ( 30 tokenPath = "/api/v2/token" 31 activeConnectionsPath = "/api/v2/connections" 32 quotasBasePath = "/api/v2/quotas" 33 quotaScanPath = "/api/v2/quotas/users/scans" 34 quotaScanVFolderPath = "/api/v2/quotas/folders/scans" 35 userPath = "/api/v2/users" 36 versionPath = "/api/v2/version" 37 folderPath = "/api/v2/folders" 38 serverStatusPath = "/api/v2/status" 39 dumpDataPath = "/api/v2/dumpdata" 40 loadDataPath = "/api/v2/loaddata" 41 defenderHosts = "/api/v2/defender/hosts" 42 defenderBanTime = "/api/v2/defender/bantime" 43 defenderUnban = "/api/v2/defender/unban" 44 defenderScore = "/api/v2/defender/score" 45 adminPath = "/api/v2/admins" 46 adminPwdPath = "/api/v2/admin/changepwd" 47 apiKeysPath = "/api/v2/apikeys" 48 retentionBasePath = "/api/v2/retention/users" 49 retentionChecksPath = "/api/v2/retention/users/checks" 50) 51 52const ( 53 defaultTokenAuthUser = "admin" 54 defaultTokenAuthPass = "password" 55) 56 57var ( 58 httpBaseURL = "http://127.0.0.1:8080" 59 jwtToken = "" 60) 61 62// SetBaseURL sets the base url to use for HTTP requests. 63// Default URL is "http://127.0.0.1:8080" 64func SetBaseURL(url string) { 65 httpBaseURL = url 66} 67 68// SetJWTToken sets the JWT token to use 69func SetJWTToken(token string) { 70 jwtToken = token 71} 72 73func sendHTTPRequest(method, url string, body io.Reader, contentType, token string) (*http.Response, error) { 74 req, err := http.NewRequest(method, url, body) 75 if err != nil { 76 return nil, err 77 } 78 if contentType != "" { 79 req.Header.Set("Content-Type", "application/json") 80 } 81 if token != "" { 82 req.Header.Set("Authorization", fmt.Sprintf("Bearer %v", token)) 83 } 84 return httpclient.GetHTTPClient().Do(req) 85} 86 87func buildURLRelativeToBase(paths ...string) string { 88 // we need to use path.Join and not filepath.Join 89 // since filepath.Join will use backslash separator on Windows 90 p := path.Join(paths...) 91 return fmt.Sprintf("%s/%s", strings.TrimRight(httpBaseURL, "/"), strings.TrimLeft(p, "/")) 92} 93 94// GetToken tries to return a JWT token 95func GetToken(username, password string) (string, map[string]interface{}, error) { 96 req, err := http.NewRequest(http.MethodGet, buildURLRelativeToBase(tokenPath), nil) 97 if err != nil { 98 return "", nil, err 99 } 100 req.SetBasicAuth(username, password) 101 resp, err := httpclient.GetHTTPClient().Do(req) 102 if err != nil { 103 return "", nil, err 104 } 105 defer resp.Body.Close() 106 107 err = checkResponse(resp.StatusCode, http.StatusOK) 108 if err != nil { 109 return "", nil, err 110 } 111 responseHolder := make(map[string]interface{}) 112 err = render.DecodeJSON(resp.Body, &responseHolder) 113 if err != nil { 114 return "", nil, err 115 } 116 return responseHolder["access_token"].(string), responseHolder, nil 117} 118 119func getDefaultToken() string { 120 if jwtToken != "" { 121 return jwtToken 122 } 123 token, _, err := GetToken(defaultTokenAuthUser, defaultTokenAuthPass) 124 if err != nil { 125 return "" 126 } 127 return token 128} 129 130// AddUser adds a new user and checks the received HTTP Status code against expectedStatusCode. 131func AddUser(user dataprovider.User, expectedStatusCode int) (dataprovider.User, []byte, error) { 132 var newUser dataprovider.User 133 var body []byte 134 userAsJSON, _ := json.Marshal(user) 135 resp, err := sendHTTPRequest(http.MethodPost, buildURLRelativeToBase(userPath), bytes.NewBuffer(userAsJSON), 136 "application/json", getDefaultToken()) 137 if err != nil { 138 return newUser, body, err 139 } 140 defer resp.Body.Close() 141 err = checkResponse(resp.StatusCode, expectedStatusCode) 142 if expectedStatusCode != http.StatusCreated { 143 body, _ = getResponseBody(resp) 144 return newUser, body, err 145 } 146 if err == nil { 147 err = render.DecodeJSON(resp.Body, &newUser) 148 } else { 149 body, _ = getResponseBody(resp) 150 } 151 if err == nil { 152 err = checkUser(&user, &newUser) 153 } 154 return newUser, body, err 155} 156 157// UpdateUserWithJSON update a user using the provided JSON as POST body 158func UpdateUserWithJSON(user dataprovider.User, expectedStatusCode int, disconnect string, userAsJSON []byte) (dataprovider.User, []byte, error) { 159 var newUser dataprovider.User 160 var body []byte 161 url, err := addDisconnectQueryParam(buildURLRelativeToBase(userPath, url.PathEscape(user.Username)), disconnect) 162 if err != nil { 163 return user, body, err 164 } 165 resp, err := sendHTTPRequest(http.MethodPut, url.String(), bytes.NewBuffer(userAsJSON), "application/json", 166 getDefaultToken()) 167 if err != nil { 168 return user, body, err 169 } 170 defer resp.Body.Close() 171 body, _ = getResponseBody(resp) 172 err = checkResponse(resp.StatusCode, expectedStatusCode) 173 if expectedStatusCode != http.StatusOK { 174 return newUser, body, err 175 } 176 if err == nil { 177 newUser, body, err = GetUserByUsername(user.Username, expectedStatusCode) 178 } 179 if err == nil { 180 err = checkUser(&user, &newUser) 181 } 182 return newUser, body, err 183} 184 185// UpdateUser updates an existing user and checks the received HTTP Status code against expectedStatusCode. 186func UpdateUser(user dataprovider.User, expectedStatusCode int, disconnect string) (dataprovider.User, []byte, error) { 187 userAsJSON, _ := json.Marshal(user) 188 return UpdateUserWithJSON(user, expectedStatusCode, disconnect, userAsJSON) 189} 190 191// RemoveUser removes an existing user and checks the received HTTP Status code against expectedStatusCode. 192func RemoveUser(user dataprovider.User, expectedStatusCode int) ([]byte, error) { 193 var body []byte 194 resp, err := sendHTTPRequest(http.MethodDelete, buildURLRelativeToBase(userPath, url.PathEscape(user.Username)), 195 nil, "", getDefaultToken()) 196 if err != nil { 197 return body, err 198 } 199 defer resp.Body.Close() 200 body, _ = getResponseBody(resp) 201 return body, checkResponse(resp.StatusCode, expectedStatusCode) 202} 203 204// GetUserByUsername gets a user by username and checks the received HTTP Status code against expectedStatusCode. 205func GetUserByUsername(username string, expectedStatusCode int) (dataprovider.User, []byte, error) { 206 var user dataprovider.User 207 var body []byte 208 resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(userPath, url.PathEscape(username)), 209 nil, "", getDefaultToken()) 210 if err != nil { 211 return user, body, err 212 } 213 defer resp.Body.Close() 214 err = checkResponse(resp.StatusCode, expectedStatusCode) 215 if err == nil && expectedStatusCode == http.StatusOK { 216 err = render.DecodeJSON(resp.Body, &user) 217 } else { 218 body, _ = getResponseBody(resp) 219 } 220 return user, body, err 221} 222 223// GetUsers returns a list of users and checks the received HTTP Status code against expectedStatusCode. 224// The number of results can be limited specifying a limit. 225// Some results can be skipped specifying an offset. 226func GetUsers(limit, offset int64, expectedStatusCode int) ([]dataprovider.User, []byte, error) { 227 var users []dataprovider.User 228 var body []byte 229 url, err := addLimitAndOffsetQueryParams(buildURLRelativeToBase(userPath), limit, offset) 230 if err != nil { 231 return users, body, err 232 } 233 resp, err := sendHTTPRequest(http.MethodGet, url.String(), nil, "", getDefaultToken()) 234 if err != nil { 235 return users, body, err 236 } 237 defer resp.Body.Close() 238 err = checkResponse(resp.StatusCode, expectedStatusCode) 239 if err == nil && expectedStatusCode == http.StatusOK { 240 err = render.DecodeJSON(resp.Body, &users) 241 } else { 242 body, _ = getResponseBody(resp) 243 } 244 return users, body, err 245} 246 247// AddAdmin adds a new admin and checks the received HTTP Status code against expectedStatusCode. 248func AddAdmin(admin dataprovider.Admin, expectedStatusCode int) (dataprovider.Admin, []byte, error) { 249 var newAdmin dataprovider.Admin 250 var body []byte 251 asJSON, _ := json.Marshal(admin) 252 resp, err := sendHTTPRequest(http.MethodPost, buildURLRelativeToBase(adminPath), bytes.NewBuffer(asJSON), 253 "application/json", getDefaultToken()) 254 if err != nil { 255 return newAdmin, body, err 256 } 257 defer resp.Body.Close() 258 err = checkResponse(resp.StatusCode, expectedStatusCode) 259 if expectedStatusCode != http.StatusCreated { 260 body, _ = getResponseBody(resp) 261 return newAdmin, body, err 262 } 263 if err == nil { 264 err = render.DecodeJSON(resp.Body, &newAdmin) 265 } else { 266 body, _ = getResponseBody(resp) 267 } 268 if err == nil { 269 err = checkAdmin(&admin, &newAdmin) 270 } 271 return newAdmin, body, err 272} 273 274// UpdateAdmin updates an existing admin and checks the received HTTP Status code against expectedStatusCode 275func UpdateAdmin(admin dataprovider.Admin, expectedStatusCode int) (dataprovider.Admin, []byte, error) { 276 var newAdmin dataprovider.Admin 277 var body []byte 278 279 asJSON, _ := json.Marshal(admin) 280 resp, err := sendHTTPRequest(http.MethodPut, buildURLRelativeToBase(adminPath, url.PathEscape(admin.Username)), 281 bytes.NewBuffer(asJSON), "application/json", getDefaultToken()) 282 if err != nil { 283 return newAdmin, body, err 284 } 285 defer resp.Body.Close() 286 body, _ = getResponseBody(resp) 287 err = checkResponse(resp.StatusCode, expectedStatusCode) 288 if expectedStatusCode != http.StatusOK { 289 return newAdmin, body, err 290 } 291 if err == nil { 292 newAdmin, body, err = GetAdminByUsername(admin.Username, expectedStatusCode) 293 } 294 if err == nil { 295 err = checkAdmin(&admin, &newAdmin) 296 } 297 return newAdmin, body, err 298} 299 300// RemoveAdmin removes an existing admin and checks the received HTTP Status code against expectedStatusCode. 301func RemoveAdmin(admin dataprovider.Admin, expectedStatusCode int) ([]byte, error) { 302 var body []byte 303 resp, err := sendHTTPRequest(http.MethodDelete, buildURLRelativeToBase(adminPath, url.PathEscape(admin.Username)), 304 nil, "", getDefaultToken()) 305 if err != nil { 306 return body, err 307 } 308 defer resp.Body.Close() 309 body, _ = getResponseBody(resp) 310 return body, checkResponse(resp.StatusCode, expectedStatusCode) 311} 312 313// GetAdminByUsername gets an admin by username and checks the received HTTP Status code against expectedStatusCode. 314func GetAdminByUsername(username string, expectedStatusCode int) (dataprovider.Admin, []byte, error) { 315 var admin dataprovider.Admin 316 var body []byte 317 resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(adminPath, url.PathEscape(username)), 318 nil, "", getDefaultToken()) 319 if err != nil { 320 return admin, body, err 321 } 322 defer resp.Body.Close() 323 err = checkResponse(resp.StatusCode, expectedStatusCode) 324 if err == nil && expectedStatusCode == http.StatusOK { 325 err = render.DecodeJSON(resp.Body, &admin) 326 } else { 327 body, _ = getResponseBody(resp) 328 } 329 return admin, body, err 330} 331 332// GetAdmins returns a list of admins and checks the received HTTP Status code against expectedStatusCode. 333// The number of results can be limited specifying a limit. 334// Some results can be skipped specifying an offset. 335func GetAdmins(limit, offset int64, expectedStatusCode int) ([]dataprovider.Admin, []byte, error) { 336 var admins []dataprovider.Admin 337 var body []byte 338 url, err := addLimitAndOffsetQueryParams(buildURLRelativeToBase(adminPath), limit, offset) 339 if err != nil { 340 return admins, body, err 341 } 342 resp, err := sendHTTPRequest(http.MethodGet, url.String(), nil, "", getDefaultToken()) 343 if err != nil { 344 return admins, body, err 345 } 346 defer resp.Body.Close() 347 err = checkResponse(resp.StatusCode, expectedStatusCode) 348 if err == nil && expectedStatusCode == http.StatusOK { 349 err = render.DecodeJSON(resp.Body, &admins) 350 } else { 351 body, _ = getResponseBody(resp) 352 } 353 return admins, body, err 354} 355 356// ChangeAdminPassword changes the password for an existing admin 357func ChangeAdminPassword(currentPassword, newPassword string, expectedStatusCode int) ([]byte, error) { 358 var body []byte 359 360 pwdChange := make(map[string]string) 361 pwdChange["current_password"] = currentPassword 362 pwdChange["new_password"] = newPassword 363 364 asJSON, _ := json.Marshal(&pwdChange) 365 resp, err := sendHTTPRequest(http.MethodPut, buildURLRelativeToBase(adminPwdPath), 366 bytes.NewBuffer(asJSON), "application/json", getDefaultToken()) 367 if err != nil { 368 return body, err 369 } 370 defer resp.Body.Close() 371 372 err = checkResponse(resp.StatusCode, expectedStatusCode) 373 body, _ = getResponseBody(resp) 374 375 return body, err 376} 377 378// GetAPIKeys returns a list of API keys and checks the received HTTP Status code against expectedStatusCode. 379// The number of results can be limited specifying a limit. 380// Some results can be skipped specifying an offset. 381func GetAPIKeys(limit, offset int64, expectedStatusCode int) ([]dataprovider.APIKey, []byte, error) { 382 var apiKeys []dataprovider.APIKey 383 var body []byte 384 url, err := addLimitAndOffsetQueryParams(buildURLRelativeToBase(apiKeysPath), limit, offset) 385 if err != nil { 386 return apiKeys, body, err 387 } 388 resp, err := sendHTTPRequest(http.MethodGet, url.String(), nil, "", getDefaultToken()) 389 if err != nil { 390 return apiKeys, body, err 391 } 392 defer resp.Body.Close() 393 err = checkResponse(resp.StatusCode, expectedStatusCode) 394 if err == nil && expectedStatusCode == http.StatusOK { 395 err = render.DecodeJSON(resp.Body, &apiKeys) 396 } else { 397 body, _ = getResponseBody(resp) 398 } 399 return apiKeys, body, err 400} 401 402// AddAPIKey adds a new API key and checks the received HTTP Status code against expectedStatusCode. 403func AddAPIKey(apiKey dataprovider.APIKey, expectedStatusCode int) (dataprovider.APIKey, []byte, error) { 404 var newAPIKey dataprovider.APIKey 405 var body []byte 406 asJSON, _ := json.Marshal(apiKey) 407 resp, err := sendHTTPRequest(http.MethodPost, buildURLRelativeToBase(apiKeysPath), bytes.NewBuffer(asJSON), 408 "application/json", getDefaultToken()) 409 if err != nil { 410 return newAPIKey, body, err 411 } 412 defer resp.Body.Close() 413 err = checkResponse(resp.StatusCode, expectedStatusCode) 414 if expectedStatusCode != http.StatusCreated { 415 body, _ = getResponseBody(resp) 416 return newAPIKey, body, err 417 } 418 if err != nil { 419 body, _ = getResponseBody(resp) 420 return newAPIKey, body, err 421 } 422 response := make(map[string]string) 423 err = render.DecodeJSON(resp.Body, &response) 424 if err == nil { 425 newAPIKey, body, err = GetAPIKeyByID(resp.Header.Get("X-Object-ID"), http.StatusOK) 426 } 427 if err == nil { 428 err = checkAPIKey(&apiKey, &newAPIKey) 429 } 430 newAPIKey.Key = response["key"] 431 432 return newAPIKey, body, err 433} 434 435// UpdateAPIKey updates an existing API key and checks the received HTTP Status code against expectedStatusCode 436func UpdateAPIKey(apiKey dataprovider.APIKey, expectedStatusCode int) (dataprovider.APIKey, []byte, error) { 437 var newAPIKey dataprovider.APIKey 438 var body []byte 439 440 asJSON, _ := json.Marshal(apiKey) 441 resp, err := sendHTTPRequest(http.MethodPut, buildURLRelativeToBase(apiKeysPath, url.PathEscape(apiKey.KeyID)), 442 bytes.NewBuffer(asJSON), "application/json", getDefaultToken()) 443 if err != nil { 444 return newAPIKey, body, err 445 } 446 defer resp.Body.Close() 447 body, _ = getResponseBody(resp) 448 err = checkResponse(resp.StatusCode, expectedStatusCode) 449 if expectedStatusCode != http.StatusOK { 450 return newAPIKey, body, err 451 } 452 if err == nil { 453 newAPIKey, body, err = GetAPIKeyByID(apiKey.KeyID, expectedStatusCode) 454 } 455 if err == nil { 456 err = checkAPIKey(&apiKey, &newAPIKey) 457 } 458 return newAPIKey, body, err 459} 460 461// RemoveAPIKey removes an existing API key and checks the received HTTP Status code against expectedStatusCode. 462func RemoveAPIKey(apiKey dataprovider.APIKey, expectedStatusCode int) ([]byte, error) { 463 var body []byte 464 resp, err := sendHTTPRequest(http.MethodDelete, buildURLRelativeToBase(apiKeysPath, url.PathEscape(apiKey.KeyID)), 465 nil, "", getDefaultToken()) 466 if err != nil { 467 return body, err 468 } 469 defer resp.Body.Close() 470 body, _ = getResponseBody(resp) 471 return body, checkResponse(resp.StatusCode, expectedStatusCode) 472} 473 474// GetAPIKeyByID gets a API key by ID and checks the received HTTP Status code against expectedStatusCode. 475func GetAPIKeyByID(keyID string, expectedStatusCode int) (dataprovider.APIKey, []byte, error) { 476 var apiKey dataprovider.APIKey 477 var body []byte 478 resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(apiKeysPath, url.PathEscape(keyID)), 479 nil, "", getDefaultToken()) 480 if err != nil { 481 return apiKey, body, err 482 } 483 defer resp.Body.Close() 484 err = checkResponse(resp.StatusCode, expectedStatusCode) 485 if err == nil && expectedStatusCode == http.StatusOK { 486 err = render.DecodeJSON(resp.Body, &apiKey) 487 } else { 488 body, _ = getResponseBody(resp) 489 } 490 return apiKey, body, err 491} 492 493// GetQuotaScans gets active quota scans for users and checks the received HTTP Status code against expectedStatusCode. 494func GetQuotaScans(expectedStatusCode int) ([]common.ActiveQuotaScan, []byte, error) { 495 var quotaScans []common.ActiveQuotaScan 496 var body []byte 497 resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(quotaScanPath), nil, "", getDefaultToken()) 498 if err != nil { 499 return quotaScans, body, err 500 } 501 defer resp.Body.Close() 502 err = checkResponse(resp.StatusCode, expectedStatusCode) 503 if err == nil && expectedStatusCode == http.StatusOK { 504 err = render.DecodeJSON(resp.Body, "aScans) 505 } else { 506 body, _ = getResponseBody(resp) 507 } 508 return quotaScans, body, err 509} 510 511// StartQuotaScan starts a new quota scan for the given user and checks the received HTTP Status code against expectedStatusCode. 512func StartQuotaScan(user dataprovider.User, expectedStatusCode int) ([]byte, error) { 513 var body []byte 514 resp, err := sendHTTPRequest(http.MethodPost, buildURLRelativeToBase(quotasBasePath, "users", user.Username, "scan"), 515 nil, "", getDefaultToken()) 516 if err != nil { 517 return body, err 518 } 519 defer resp.Body.Close() 520 body, _ = getResponseBody(resp) 521 return body, checkResponse(resp.StatusCode, expectedStatusCode) 522} 523 524// UpdateQuotaUsage updates the user used quota limits and checks the received HTTP Status code against expectedStatusCode. 525func UpdateQuotaUsage(user dataprovider.User, mode string, expectedStatusCode int) ([]byte, error) { 526 var body []byte 527 userAsJSON, _ := json.Marshal(user) 528 url, err := addModeQueryParam(buildURLRelativeToBase(quotasBasePath, "users", user.Username, "usage"), mode) 529 if err != nil { 530 return body, err 531 } 532 resp, err := sendHTTPRequest(http.MethodPut, url.String(), bytes.NewBuffer(userAsJSON), "application/json", 533 getDefaultToken()) 534 if err != nil { 535 return body, err 536 } 537 defer resp.Body.Close() 538 body, _ = getResponseBody(resp) 539 return body, checkResponse(resp.StatusCode, expectedStatusCode) 540} 541 542// GetRetentionChecks returns the active retention checks 543func GetRetentionChecks(expectedStatusCode int) ([]common.ActiveRetentionChecks, []byte, error) { 544 var checks []common.ActiveRetentionChecks 545 var body []byte 546 resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(retentionChecksPath), nil, "", getDefaultToken()) 547 if err != nil { 548 return checks, body, err 549 } 550 defer resp.Body.Close() 551 err = checkResponse(resp.StatusCode, expectedStatusCode) 552 if err == nil && expectedStatusCode == http.StatusOK { 553 err = render.DecodeJSON(resp.Body, &checks) 554 } else { 555 body, _ = getResponseBody(resp) 556 } 557 return checks, body, err 558} 559 560// StartRetentionCheck starts a new retention check 561func StartRetentionCheck(username string, retention []common.FolderRetention, expectedStatusCode int) ([]byte, error) { 562 var body []byte 563 asJSON, _ := json.Marshal(retention) 564 resp, err := sendHTTPRequest(http.MethodPost, buildURLRelativeToBase(retentionBasePath, username, "check"), 565 bytes.NewBuffer(asJSON), "application/json", getDefaultToken()) 566 if err != nil { 567 return body, err 568 } 569 defer resp.Body.Close() 570 body, _ = getResponseBody(resp) 571 return body, checkResponse(resp.StatusCode, expectedStatusCode) 572} 573 574// GetConnections returns status and stats for active SFTP/SCP connections 575func GetConnections(expectedStatusCode int) ([]common.ConnectionStatus, []byte, error) { 576 var connections []common.ConnectionStatus 577 var body []byte 578 resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(activeConnectionsPath), nil, "", getDefaultToken()) 579 if err != nil { 580 return connections, body, err 581 } 582 defer resp.Body.Close() 583 err = checkResponse(resp.StatusCode, expectedStatusCode) 584 if err == nil && expectedStatusCode == http.StatusOK { 585 err = render.DecodeJSON(resp.Body, &connections) 586 } else { 587 body, _ = getResponseBody(resp) 588 } 589 return connections, body, err 590} 591 592// CloseConnection closes an active connection identified by connectionID 593func CloseConnection(connectionID string, expectedStatusCode int) ([]byte, error) { 594 var body []byte 595 resp, err := sendHTTPRequest(http.MethodDelete, buildURLRelativeToBase(activeConnectionsPath, connectionID), 596 nil, "", getDefaultToken()) 597 if err != nil { 598 return body, err 599 } 600 defer resp.Body.Close() 601 err = checkResponse(resp.StatusCode, expectedStatusCode) 602 body, _ = getResponseBody(resp) 603 return body, err 604} 605 606// AddFolder adds a new folder and checks the received HTTP Status code against expectedStatusCode 607func AddFolder(folder vfs.BaseVirtualFolder, expectedStatusCode int) (vfs.BaseVirtualFolder, []byte, error) { 608 var newFolder vfs.BaseVirtualFolder 609 var body []byte 610 folderAsJSON, _ := json.Marshal(folder) 611 resp, err := sendHTTPRequest(http.MethodPost, buildURLRelativeToBase(folderPath), bytes.NewBuffer(folderAsJSON), 612 "application/json", getDefaultToken()) 613 if err != nil { 614 return newFolder, body, err 615 } 616 defer resp.Body.Close() 617 err = checkResponse(resp.StatusCode, expectedStatusCode) 618 if expectedStatusCode != http.StatusCreated { 619 body, _ = getResponseBody(resp) 620 return newFolder, body, err 621 } 622 if err == nil { 623 err = render.DecodeJSON(resp.Body, &newFolder) 624 } else { 625 body, _ = getResponseBody(resp) 626 } 627 if err == nil { 628 err = checkFolder(&folder, &newFolder) 629 } 630 return newFolder, body, err 631} 632 633// UpdateFolder updates an existing folder and checks the received HTTP Status code against expectedStatusCode. 634func UpdateFolder(folder vfs.BaseVirtualFolder, expectedStatusCode int) (vfs.BaseVirtualFolder, []byte, error) { 635 var updatedFolder vfs.BaseVirtualFolder 636 var body []byte 637 638 folderAsJSON, _ := json.Marshal(folder) 639 resp, err := sendHTTPRequest(http.MethodPut, buildURLRelativeToBase(folderPath, url.PathEscape(folder.Name)), 640 bytes.NewBuffer(folderAsJSON), "application/json", getDefaultToken()) 641 if err != nil { 642 return updatedFolder, body, err 643 } 644 defer resp.Body.Close() 645 body, _ = getResponseBody(resp) 646 647 err = checkResponse(resp.StatusCode, expectedStatusCode) 648 if expectedStatusCode != http.StatusOK { 649 return updatedFolder, body, err 650 } 651 if err == nil { 652 updatedFolder, body, err = GetFolderByName(folder.Name, expectedStatusCode) 653 } 654 if err == nil { 655 err = checkFolder(&folder, &updatedFolder) 656 } 657 return updatedFolder, body, err 658} 659 660// RemoveFolder removes an existing user and checks the received HTTP Status code against expectedStatusCode. 661func RemoveFolder(folder vfs.BaseVirtualFolder, expectedStatusCode int) ([]byte, error) { 662 var body []byte 663 resp, err := sendHTTPRequest(http.MethodDelete, buildURLRelativeToBase(folderPath, url.PathEscape(folder.Name)), 664 nil, "", getDefaultToken()) 665 if err != nil { 666 return body, err 667 } 668 defer resp.Body.Close() 669 body, _ = getResponseBody(resp) 670 return body, checkResponse(resp.StatusCode, expectedStatusCode) 671} 672 673// GetFolderByName gets a folder by name and checks the received HTTP Status code against expectedStatusCode. 674func GetFolderByName(name string, expectedStatusCode int) (vfs.BaseVirtualFolder, []byte, error) { 675 var folder vfs.BaseVirtualFolder 676 var body []byte 677 resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(folderPath, url.PathEscape(name)), 678 nil, "", getDefaultToken()) 679 if err != nil { 680 return folder, body, err 681 } 682 defer resp.Body.Close() 683 err = checkResponse(resp.StatusCode, expectedStatusCode) 684 if err == nil && expectedStatusCode == http.StatusOK { 685 err = render.DecodeJSON(resp.Body, &folder) 686 } else { 687 body, _ = getResponseBody(resp) 688 } 689 return folder, body, err 690} 691 692// GetFolders returns a list of folders and checks the received HTTP Status code against expectedStatusCode. 693// The number of results can be limited specifying a limit. 694// Some results can be skipped specifying an offset. 695// The results can be filtered specifying a folder path, the folder path filter is an exact match 696func GetFolders(limit int64, offset int64, expectedStatusCode int) ([]vfs.BaseVirtualFolder, []byte, error) { 697 var folders []vfs.BaseVirtualFolder 698 var body []byte 699 url, err := addLimitAndOffsetQueryParams(buildURLRelativeToBase(folderPath), limit, offset) 700 if err != nil { 701 return folders, body, err 702 } 703 resp, err := sendHTTPRequest(http.MethodGet, url.String(), nil, "", getDefaultToken()) 704 if err != nil { 705 return folders, body, err 706 } 707 defer resp.Body.Close() 708 err = checkResponse(resp.StatusCode, expectedStatusCode) 709 if err == nil && expectedStatusCode == http.StatusOK { 710 err = render.DecodeJSON(resp.Body, &folders) 711 } else { 712 body, _ = getResponseBody(resp) 713 } 714 return folders, body, err 715} 716 717// GetFoldersQuotaScans gets active quota scans for folders and checks the received HTTP Status code against expectedStatusCode. 718func GetFoldersQuotaScans(expectedStatusCode int) ([]common.ActiveVirtualFolderQuotaScan, []byte, error) { 719 var quotaScans []common.ActiveVirtualFolderQuotaScan 720 var body []byte 721 resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(quotaScanVFolderPath), nil, "", getDefaultToken()) 722 if err != nil { 723 return quotaScans, body, err 724 } 725 defer resp.Body.Close() 726 err = checkResponse(resp.StatusCode, expectedStatusCode) 727 if err == nil && expectedStatusCode == http.StatusOK { 728 err = render.DecodeJSON(resp.Body, "aScans) 729 } else { 730 body, _ = getResponseBody(resp) 731 } 732 return quotaScans, body, err 733} 734 735// StartFolderQuotaScan start a new quota scan for the given folder and checks the received HTTP Status code against expectedStatusCode. 736func StartFolderQuotaScan(folder vfs.BaseVirtualFolder, expectedStatusCode int) ([]byte, error) { 737 var body []byte 738 resp, err := sendHTTPRequest(http.MethodPost, buildURLRelativeToBase(quotasBasePath, "folders", folder.Name, "scan"), 739 nil, "", getDefaultToken()) 740 if err != nil { 741 return body, err 742 } 743 defer resp.Body.Close() 744 body, _ = getResponseBody(resp) 745 return body, checkResponse(resp.StatusCode, expectedStatusCode) 746} 747 748// UpdateFolderQuotaUsage updates the folder used quota limits and checks the received HTTP Status code against expectedStatusCode. 749func UpdateFolderQuotaUsage(folder vfs.BaseVirtualFolder, mode string, expectedStatusCode int) ([]byte, error) { 750 var body []byte 751 folderAsJSON, _ := json.Marshal(folder) 752 url, err := addModeQueryParam(buildURLRelativeToBase(quotasBasePath, "folders", folder.Name, "usage"), mode) 753 if err != nil { 754 return body, err 755 } 756 resp, err := sendHTTPRequest(http.MethodPut, url.String(), bytes.NewBuffer(folderAsJSON), "", getDefaultToken()) 757 if err != nil { 758 return body, err 759 } 760 defer resp.Body.Close() 761 body, _ = getResponseBody(resp) 762 return body, checkResponse(resp.StatusCode, expectedStatusCode) 763} 764 765// GetVersion returns version details 766func GetVersion(expectedStatusCode int) (version.Info, []byte, error) { 767 var appVersion version.Info 768 var body []byte 769 resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(versionPath), nil, "", getDefaultToken()) 770 if err != nil { 771 return appVersion, body, err 772 } 773 defer resp.Body.Close() 774 err = checkResponse(resp.StatusCode, expectedStatusCode) 775 if err == nil && expectedStatusCode == http.StatusOK { 776 err = render.DecodeJSON(resp.Body, &appVersion) 777 } else { 778 body, _ = getResponseBody(resp) 779 } 780 return appVersion, body, err 781} 782 783// GetStatus returns the server status 784func GetStatus(expectedStatusCode int) (httpd.ServicesStatus, []byte, error) { 785 var response httpd.ServicesStatus 786 var body []byte 787 resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(serverStatusPath), nil, "", getDefaultToken()) 788 if err != nil { 789 return response, body, err 790 } 791 defer resp.Body.Close() 792 err = checkResponse(resp.StatusCode, expectedStatusCode) 793 if err == nil && (expectedStatusCode == http.StatusOK) { 794 err = render.DecodeJSON(resp.Body, &response) 795 } else { 796 body, _ = getResponseBody(resp) 797 } 798 return response, body, err 799} 800 801// GetDefenderHosts returns hosts that are banned or for which some violations have been detected 802func GetDefenderHosts(expectedStatusCode int) ([]common.DefenderEntry, []byte, error) { 803 var response []common.DefenderEntry 804 var body []byte 805 url, err := url.Parse(buildURLRelativeToBase(defenderHosts)) 806 if err != nil { 807 return response, body, err 808 } 809 resp, err := sendHTTPRequest(http.MethodGet, url.String(), nil, "", getDefaultToken()) 810 if err != nil { 811 return response, body, err 812 } 813 defer resp.Body.Close() 814 err = checkResponse(resp.StatusCode, expectedStatusCode) 815 if err == nil && expectedStatusCode == http.StatusOK { 816 err = render.DecodeJSON(resp.Body, &response) 817 } else { 818 body, _ = getResponseBody(resp) 819 } 820 return response, body, err 821} 822 823// GetDefenderHostByIP returns the host with the given IP, if it exists 824func GetDefenderHostByIP(ip string, expectedStatusCode int) (common.DefenderEntry, []byte, error) { 825 var host common.DefenderEntry 826 var body []byte 827 id := hex.EncodeToString([]byte(ip)) 828 resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(defenderHosts, id), 829 nil, "", getDefaultToken()) 830 if err != nil { 831 return host, body, err 832 } 833 defer resp.Body.Close() 834 err = checkResponse(resp.StatusCode, expectedStatusCode) 835 if err == nil && expectedStatusCode == http.StatusOK { 836 err = render.DecodeJSON(resp.Body, &host) 837 } else { 838 body, _ = getResponseBody(resp) 839 } 840 return host, body, err 841} 842 843// RemoveDefenderHostByIP removes the host with the given IP from the defender list 844func RemoveDefenderHostByIP(ip string, expectedStatusCode int) ([]byte, error) { 845 var body []byte 846 id := hex.EncodeToString([]byte(ip)) 847 resp, err := sendHTTPRequest(http.MethodDelete, buildURLRelativeToBase(defenderHosts, id), nil, "", getDefaultToken()) 848 if err != nil { 849 return body, err 850 } 851 defer resp.Body.Close() 852 body, _ = getResponseBody(resp) 853 return body, checkResponse(resp.StatusCode, expectedStatusCode) 854} 855 856// GetBanTime returns the ban time for the given IP address 857func GetBanTime(ip string, expectedStatusCode int) (map[string]interface{}, []byte, error) { 858 var response map[string]interface{} 859 var body []byte 860 url, err := url.Parse(buildURLRelativeToBase(defenderBanTime)) 861 if err != nil { 862 return response, body, err 863 } 864 q := url.Query() 865 q.Add("ip", ip) 866 url.RawQuery = q.Encode() 867 resp, err := sendHTTPRequest(http.MethodGet, url.String(), nil, "", getDefaultToken()) 868 if err != nil { 869 return response, body, err 870 } 871 defer resp.Body.Close() 872 err = checkResponse(resp.StatusCode, expectedStatusCode) 873 if err == nil && expectedStatusCode == http.StatusOK { 874 err = render.DecodeJSON(resp.Body, &response) 875 } else { 876 body, _ = getResponseBody(resp) 877 } 878 return response, body, err 879} 880 881// GetScore returns the score for the given IP address 882func GetScore(ip string, expectedStatusCode int) (map[string]interface{}, []byte, error) { 883 var response map[string]interface{} 884 var body []byte 885 url, err := url.Parse(buildURLRelativeToBase(defenderScore)) 886 if err != nil { 887 return response, body, err 888 } 889 q := url.Query() 890 q.Add("ip", ip) 891 url.RawQuery = q.Encode() 892 resp, err := sendHTTPRequest(http.MethodGet, url.String(), nil, "", getDefaultToken()) 893 if err != nil { 894 return response, body, err 895 } 896 defer resp.Body.Close() 897 err = checkResponse(resp.StatusCode, expectedStatusCode) 898 if err == nil && expectedStatusCode == http.StatusOK { 899 err = render.DecodeJSON(resp.Body, &response) 900 } else { 901 body, _ = getResponseBody(resp) 902 } 903 return response, body, err 904} 905 906// UnbanIP unbans the given IP address 907func UnbanIP(ip string, expectedStatusCode int) error { 908 postBody := make(map[string]string) 909 postBody["ip"] = ip 910 asJSON, _ := json.Marshal(postBody) 911 resp, err := sendHTTPRequest(http.MethodPost, buildURLRelativeToBase(defenderUnban), bytes.NewBuffer(asJSON), 912 "", getDefaultToken()) 913 if err != nil { 914 return err 915 } 916 defer resp.Body.Close() 917 return checkResponse(resp.StatusCode, expectedStatusCode) 918} 919 920// Dumpdata requests a backup to outputFile. 921// outputFile is relative to the configured backups_path 922func Dumpdata(outputFile, outputData, indent string, expectedStatusCode int) (map[string]interface{}, []byte, error) { 923 var response map[string]interface{} 924 var body []byte 925 url, err := url.Parse(buildURLRelativeToBase(dumpDataPath)) 926 if err != nil { 927 return response, body, err 928 } 929 q := url.Query() 930 if outputData != "" { 931 q.Add("output-data", outputData) 932 } 933 if outputFile != "" { 934 q.Add("output-file", outputFile) 935 } 936 if indent != "" { 937 q.Add("indent", indent) 938 } 939 url.RawQuery = q.Encode() 940 resp, err := sendHTTPRequest(http.MethodGet, url.String(), nil, "", getDefaultToken()) 941 if err != nil { 942 return response, body, err 943 } 944 defer resp.Body.Close() 945 err = checkResponse(resp.StatusCode, expectedStatusCode) 946 if err == nil && expectedStatusCode == http.StatusOK { 947 err = render.DecodeJSON(resp.Body, &response) 948 } else { 949 body, _ = getResponseBody(resp) 950 } 951 return response, body, err 952} 953 954// Loaddata restores a backup. 955func Loaddata(inputFile, scanQuota, mode string, expectedStatusCode int) (map[string]interface{}, []byte, error) { 956 var response map[string]interface{} 957 var body []byte 958 url, err := url.Parse(buildURLRelativeToBase(loadDataPath)) 959 if err != nil { 960 return response, body, err 961 } 962 q := url.Query() 963 q.Add("input-file", inputFile) 964 if scanQuota != "" { 965 q.Add("scan-quota", scanQuota) 966 } 967 if mode != "" { 968 q.Add("mode", mode) 969 } 970 url.RawQuery = q.Encode() 971 resp, err := sendHTTPRequest(http.MethodGet, url.String(), nil, "", getDefaultToken()) 972 if err != nil { 973 return response, body, err 974 } 975 defer resp.Body.Close() 976 err = checkResponse(resp.StatusCode, expectedStatusCode) 977 if err == nil && expectedStatusCode == http.StatusOK { 978 err = render.DecodeJSON(resp.Body, &response) 979 } else { 980 body, _ = getResponseBody(resp) 981 } 982 return response, body, err 983} 984 985// LoaddataFromPostBody restores a backup 986func LoaddataFromPostBody(data []byte, scanQuota, mode string, expectedStatusCode int) (map[string]interface{}, []byte, error) { 987 var response map[string]interface{} 988 var body []byte 989 url, err := url.Parse(buildURLRelativeToBase(loadDataPath)) 990 if err != nil { 991 return response, body, err 992 } 993 q := url.Query() 994 if scanQuota != "" { 995 q.Add("scan-quota", scanQuota) 996 } 997 if mode != "" { 998 q.Add("mode", mode) 999 } 1000 url.RawQuery = q.Encode() 1001 resp, err := sendHTTPRequest(http.MethodPost, url.String(), bytes.NewReader(data), "", getDefaultToken()) 1002 if err != nil { 1003 return response, body, err 1004 } 1005 defer resp.Body.Close() 1006 err = checkResponse(resp.StatusCode, expectedStatusCode) 1007 if err == nil && expectedStatusCode == http.StatusOK { 1008 err = render.DecodeJSON(resp.Body, &response) 1009 } else { 1010 body, _ = getResponseBody(resp) 1011 } 1012 return response, body, err 1013} 1014 1015func checkResponse(actual int, expected int) error { 1016 if expected != actual { 1017 return fmt.Errorf("wrong status code: got %v want %v", actual, expected) 1018 } 1019 return nil 1020} 1021 1022func getResponseBody(resp *http.Response) ([]byte, error) { 1023 return io.ReadAll(resp.Body) 1024} 1025 1026func checkFolder(expected *vfs.BaseVirtualFolder, actual *vfs.BaseVirtualFolder) error { 1027 if expected.ID <= 0 { 1028 if actual.ID <= 0 { 1029 return errors.New("actual folder ID must be > 0") 1030 } 1031 } else { 1032 if actual.ID != expected.ID { 1033 return errors.New("folder ID mismatch") 1034 } 1035 } 1036 if expected.Name != actual.Name { 1037 return errors.New("name mismatch") 1038 } 1039 if expected.MappedPath != actual.MappedPath { 1040 return errors.New("mapped path mismatch") 1041 } 1042 if expected.Description != actual.Description { 1043 return errors.New("description mismatch") 1044 } 1045 return compareFsConfig(&expected.FsConfig, &actual.FsConfig) 1046} 1047 1048func checkAPIKey(expected, actual *dataprovider.APIKey) error { 1049 if actual.Key != "" { 1050 return errors.New("key must not be visible") 1051 } 1052 if actual.KeyID == "" { 1053 return errors.New("actual key_id cannot be empty") 1054 } 1055 if expected.Name != actual.Name { 1056 return errors.New("name mismatch") 1057 } 1058 if expected.Scope != actual.Scope { 1059 return errors.New("scope mismatch") 1060 } 1061 if actual.CreatedAt == 0 { 1062 return errors.New("created_at cannot be 0") 1063 } 1064 if actual.UpdatedAt == 0 { 1065 return errors.New("updated_at cannot be 0") 1066 } 1067 if expected.ExpiresAt != actual.ExpiresAt { 1068 return errors.New("expires_at mismatch") 1069 } 1070 if expected.Description != actual.Description { 1071 return errors.New("description mismatch") 1072 } 1073 if expected.User != actual.User { 1074 return errors.New("user mismatch") 1075 } 1076 if expected.Admin != actual.Admin { 1077 return errors.New("admin mismatch") 1078 } 1079 1080 return nil 1081} 1082 1083func checkAdmin(expected, actual *dataprovider.Admin) error { 1084 if actual.Password != "" { 1085 return errors.New("admin password must not be visible") 1086 } 1087 if expected.ID <= 0 { 1088 if actual.ID <= 0 { 1089 return errors.New("actual admin ID must be > 0") 1090 } 1091 } else { 1092 if actual.ID != expected.ID { 1093 return errors.New("admin ID mismatch") 1094 } 1095 } 1096 if expected.CreatedAt > 0 { 1097 if expected.CreatedAt != actual.CreatedAt { 1098 return fmt.Errorf("created_at mismatch %v != %v", expected.CreatedAt, actual.CreatedAt) 1099 } 1100 } 1101 if err := compareAdminEqualFields(expected, actual); err != nil { 1102 return err 1103 } 1104 if len(expected.Permissions) != len(actual.Permissions) { 1105 return errors.New("permissions mismatch") 1106 } 1107 for _, p := range expected.Permissions { 1108 if !util.IsStringInSlice(p, actual.Permissions) { 1109 return errors.New("permissions content mismatch") 1110 } 1111 } 1112 if len(expected.Filters.AllowList) != len(actual.Filters.AllowList) { 1113 return errors.New("allow list mismatch") 1114 } 1115 if expected.Filters.AllowAPIKeyAuth != actual.Filters.AllowAPIKeyAuth { 1116 return errors.New("allow_api_key_auth mismatch") 1117 } 1118 for _, v := range expected.Filters.AllowList { 1119 if !util.IsStringInSlice(v, actual.Filters.AllowList) { 1120 return errors.New("allow list content mismatch") 1121 } 1122 } 1123 1124 return nil 1125} 1126 1127func compareAdminEqualFields(expected *dataprovider.Admin, actual *dataprovider.Admin) error { 1128 if expected.Username != actual.Username { 1129 return errors.New("sername mismatch") 1130 } 1131 if expected.Email != actual.Email { 1132 return errors.New("email mismatch") 1133 } 1134 if expected.Status != actual.Status { 1135 return errors.New("status mismatch") 1136 } 1137 if expected.Description != actual.Description { 1138 return errors.New("description mismatch") 1139 } 1140 if expected.AdditionalInfo != actual.AdditionalInfo { 1141 return errors.New("additional info mismatch") 1142 } 1143 return nil 1144} 1145 1146func checkUser(expected *dataprovider.User, actual *dataprovider.User) error { 1147 if actual.Password != "" { 1148 return errors.New("user password must not be visible") 1149 } 1150 if expected.ID <= 0 { 1151 if actual.ID <= 0 { 1152 return errors.New("actual user ID must be > 0") 1153 } 1154 } else { 1155 if actual.ID != expected.ID { 1156 return errors.New("user ID mismatch") 1157 } 1158 } 1159 if expected.CreatedAt > 0 { 1160 if expected.CreatedAt != actual.CreatedAt { 1161 return fmt.Errorf("created_at mismatch %v != %v", expected.CreatedAt, actual.CreatedAt) 1162 } 1163 } 1164 1165 if expected.Email != actual.Email { 1166 return errors.New("email mismatch") 1167 } 1168 if err := compareUserPermissions(expected, actual); err != nil { 1169 return err 1170 } 1171 if err := compareUserFilters(expected, actual); err != nil { 1172 return err 1173 } 1174 if err := compareFsConfig(&expected.FsConfig, &actual.FsConfig); err != nil { 1175 return err 1176 } 1177 if err := compareUserVirtualFolders(expected, actual); err != nil { 1178 return err 1179 } 1180 return compareEqualsUserFields(expected, actual) 1181} 1182 1183func compareUserPermissions(expected *dataprovider.User, actual *dataprovider.User) error { 1184 if len(expected.Permissions) != len(actual.Permissions) { 1185 return errors.New("permissions mismatch") 1186 } 1187 for dir, perms := range expected.Permissions { 1188 if actualPerms, ok := actual.Permissions[dir]; ok { 1189 for _, v := range actualPerms { 1190 if !util.IsStringInSlice(v, perms) { 1191 return errors.New("permissions contents mismatch") 1192 } 1193 } 1194 } else { 1195 return errors.New("permissions directories mismatch") 1196 } 1197 } 1198 return nil 1199} 1200 1201func compareUserVirtualFolders(expected *dataprovider.User, actual *dataprovider.User) error { 1202 if len(actual.VirtualFolders) != len(expected.VirtualFolders) { 1203 return errors.New("virtual folders len mismatch") 1204 } 1205 for _, v := range actual.VirtualFolders { 1206 found := false 1207 for _, v1 := range expected.VirtualFolders { 1208 if path.Clean(v.VirtualPath) == path.Clean(v1.VirtualPath) { 1209 if err := checkFolder(&v1.BaseVirtualFolder, &v.BaseVirtualFolder); err != nil { 1210 return err 1211 } 1212 if v.QuotaSize != v1.QuotaSize { 1213 return errors.New("vfolder quota size mismatch") 1214 } 1215 if (v.QuotaFiles) != (v1.QuotaFiles) { 1216 return errors.New("vfolder quota files mismatch") 1217 } 1218 found = true 1219 break 1220 } 1221 } 1222 if !found { 1223 return errors.New("virtual folders mismatch") 1224 } 1225 } 1226 return nil 1227} 1228 1229func compareFsConfig(expected *vfs.Filesystem, actual *vfs.Filesystem) error { 1230 if expected.Provider != actual.Provider { 1231 return errors.New("fs provider mismatch") 1232 } 1233 if err := compareS3Config(expected, actual); err != nil { 1234 return err 1235 } 1236 if err := compareGCSConfig(expected, actual); err != nil { 1237 return err 1238 } 1239 if err := compareAzBlobConfig(expected, actual); err != nil { 1240 return err 1241 } 1242 if err := checkEncryptedSecret(expected.CryptConfig.Passphrase, actual.CryptConfig.Passphrase); err != nil { 1243 return err 1244 } 1245 return compareSFTPFsConfig(expected, actual) 1246} 1247 1248func compareS3Config(expected *vfs.Filesystem, actual *vfs.Filesystem) error { //nolint:gocyclo 1249 if expected.S3Config.Bucket != actual.S3Config.Bucket { 1250 return errors.New("fs S3 bucket mismatch") 1251 } 1252 if expected.S3Config.Region != actual.S3Config.Region { 1253 return errors.New("fs S3 region mismatch") 1254 } 1255 if expected.S3Config.AccessKey != actual.S3Config.AccessKey { 1256 return errors.New("fs S3 access key mismatch") 1257 } 1258 if err := checkEncryptedSecret(expected.S3Config.AccessSecret, actual.S3Config.AccessSecret); err != nil { 1259 return fmt.Errorf("fs S3 access secret mismatch: %v", err) 1260 } 1261 if expected.S3Config.Endpoint != actual.S3Config.Endpoint { 1262 return errors.New("fs S3 endpoint mismatch") 1263 } 1264 if expected.S3Config.StorageClass != actual.S3Config.StorageClass { 1265 return errors.New("fs S3 storage class mismatch") 1266 } 1267 if expected.S3Config.ACL != actual.S3Config.ACL { 1268 return errors.New("fs S3 ACL mismatch") 1269 } 1270 if expected.S3Config.UploadPartSize != actual.S3Config.UploadPartSize { 1271 return errors.New("fs S3 upload part size mismatch") 1272 } 1273 if expected.S3Config.UploadConcurrency != actual.S3Config.UploadConcurrency { 1274 return errors.New("fs S3 upload concurrency mismatch") 1275 } 1276 if expected.S3Config.DownloadPartSize != actual.S3Config.DownloadPartSize { 1277 return errors.New("fs S3 download part size mismatch") 1278 } 1279 if expected.S3Config.DownloadConcurrency != actual.S3Config.DownloadConcurrency { 1280 return errors.New("fs S3 download concurrency mismatch") 1281 } 1282 if expected.S3Config.ForcePathStyle != actual.S3Config.ForcePathStyle { 1283 return errors.New("fs S3 force path style mismatch") 1284 } 1285 if expected.S3Config.DownloadPartMaxTime != actual.S3Config.DownloadPartMaxTime { 1286 return errors.New("fs S3 download part max time mismatch") 1287 } 1288 if expected.S3Config.KeyPrefix != actual.S3Config.KeyPrefix && 1289 expected.S3Config.KeyPrefix+"/" != actual.S3Config.KeyPrefix { 1290 return errors.New("fs S3 key prefix mismatch") 1291 } 1292 return nil 1293} 1294 1295func compareGCSConfig(expected *vfs.Filesystem, actual *vfs.Filesystem) error { 1296 if expected.GCSConfig.Bucket != actual.GCSConfig.Bucket { 1297 return errors.New("GCS bucket mismatch") 1298 } 1299 if expected.GCSConfig.StorageClass != actual.GCSConfig.StorageClass { 1300 return errors.New("GCS storage class mismatch") 1301 } 1302 if expected.GCSConfig.ACL != actual.GCSConfig.ACL { 1303 return errors.New("GCS ACL mismatch") 1304 } 1305 if expected.GCSConfig.KeyPrefix != actual.GCSConfig.KeyPrefix && 1306 expected.GCSConfig.KeyPrefix+"/" != actual.GCSConfig.KeyPrefix { 1307 return errors.New("GCS key prefix mismatch") 1308 } 1309 if expected.GCSConfig.AutomaticCredentials != actual.GCSConfig.AutomaticCredentials { 1310 return errors.New("GCS automatic credentials mismatch") 1311 } 1312 return nil 1313} 1314 1315func compareSFTPFsConfig(expected *vfs.Filesystem, actual *vfs.Filesystem) error { 1316 if expected.SFTPConfig.Endpoint != actual.SFTPConfig.Endpoint { 1317 return errors.New("SFTPFs endpoint mismatch") 1318 } 1319 if expected.SFTPConfig.Username != actual.SFTPConfig.Username { 1320 return errors.New("SFTPFs username mismatch") 1321 } 1322 if expected.SFTPConfig.DisableCouncurrentReads != actual.SFTPConfig.DisableCouncurrentReads { 1323 return errors.New("SFTPFs disable_concurrent_reads mismatch") 1324 } 1325 if expected.SFTPConfig.BufferSize != actual.SFTPConfig.BufferSize { 1326 return errors.New("SFTPFs buffer_size mismatch") 1327 } 1328 if err := checkEncryptedSecret(expected.SFTPConfig.Password, actual.SFTPConfig.Password); err != nil { 1329 return fmt.Errorf("SFTPFs password mismatch: %v", err) 1330 } 1331 if err := checkEncryptedSecret(expected.SFTPConfig.PrivateKey, actual.SFTPConfig.PrivateKey); err != nil { 1332 return fmt.Errorf("SFTPFs private key mismatch: %v", err) 1333 } 1334 if expected.SFTPConfig.Prefix != actual.SFTPConfig.Prefix { 1335 if expected.SFTPConfig.Prefix != "" && actual.SFTPConfig.Prefix != "/" { 1336 return errors.New("SFTPFs prefix mismatch") 1337 } 1338 } 1339 if len(expected.SFTPConfig.Fingerprints) != len(actual.SFTPConfig.Fingerprints) { 1340 return errors.New("SFTPFs fingerprints mismatch") 1341 } 1342 for _, value := range actual.SFTPConfig.Fingerprints { 1343 if !util.IsStringInSlice(value, expected.SFTPConfig.Fingerprints) { 1344 return errors.New("SFTPFs fingerprints mismatch") 1345 } 1346 } 1347 return nil 1348} 1349 1350func compareAzBlobConfig(expected *vfs.Filesystem, actual *vfs.Filesystem) error { 1351 if expected.AzBlobConfig.Container != actual.AzBlobConfig.Container { 1352 return errors.New("azure Blob container mismatch") 1353 } 1354 if expected.AzBlobConfig.AccountName != actual.AzBlobConfig.AccountName { 1355 return errors.New("azure Blob account name mismatch") 1356 } 1357 if err := checkEncryptedSecret(expected.AzBlobConfig.AccountKey, actual.AzBlobConfig.AccountKey); err != nil { 1358 return fmt.Errorf("azure Blob account key mismatch: %v", err) 1359 } 1360 if expected.AzBlobConfig.Endpoint != actual.AzBlobConfig.Endpoint { 1361 return errors.New("azure Blob endpoint mismatch") 1362 } 1363 if err := checkEncryptedSecret(expected.AzBlobConfig.SASURL, actual.AzBlobConfig.SASURL); err != nil { 1364 return fmt.Errorf("azure Blob SAS URL mismatch: %v", err) 1365 } 1366 if expected.AzBlobConfig.UploadPartSize != actual.AzBlobConfig.UploadPartSize { 1367 return errors.New("azure Blob upload part size mismatch") 1368 } 1369 if expected.AzBlobConfig.UploadConcurrency != actual.AzBlobConfig.UploadConcurrency { 1370 return errors.New("azure Blob upload concurrency mismatch") 1371 } 1372 if expected.AzBlobConfig.KeyPrefix != actual.AzBlobConfig.KeyPrefix && 1373 expected.AzBlobConfig.KeyPrefix+"/" != actual.AzBlobConfig.KeyPrefix { 1374 return errors.New("azure Blob key prefix mismatch") 1375 } 1376 if expected.AzBlobConfig.UseEmulator != actual.AzBlobConfig.UseEmulator { 1377 return errors.New("azure Blob use emulator mismatch") 1378 } 1379 if expected.AzBlobConfig.AccessTier != actual.AzBlobConfig.AccessTier { 1380 return errors.New("azure Blob access tier mismatch") 1381 } 1382 return nil 1383} 1384 1385func areSecretEquals(expected, actual *kms.Secret) bool { 1386 if expected == nil && actual == nil { 1387 return true 1388 } 1389 if expected != nil && expected.IsEmpty() && actual == nil { 1390 return true 1391 } 1392 if actual != nil && actual.IsEmpty() && expected == nil { 1393 return true 1394 } 1395 return false 1396} 1397 1398func checkEncryptedSecret(expected, actual *kms.Secret) error { 1399 if areSecretEquals(expected, actual) { 1400 return nil 1401 } 1402 if expected == nil && actual != nil && !actual.IsEmpty() { 1403 return errors.New("secret mismatch") 1404 } 1405 if actual == nil && expected != nil && !expected.IsEmpty() { 1406 return errors.New("secret mismatch") 1407 } 1408 if expected.IsPlain() && actual.IsEncrypted() { 1409 if actual.GetPayload() == "" { 1410 return errors.New("invalid secret payload") 1411 } 1412 if actual.GetAdditionalData() != "" { 1413 return errors.New("invalid secret additional data") 1414 } 1415 if actual.GetKey() != "" { 1416 return errors.New("invalid secret key") 1417 } 1418 } else { 1419 if expected.GetStatus() != actual.GetStatus() || expected.GetPayload() != actual.GetPayload() { 1420 return errors.New("secret mismatch") 1421 } 1422 } 1423 return nil 1424} 1425 1426func compareUserFilterSubStructs(expected *dataprovider.User, actual *dataprovider.User) error { 1427 for _, IPMask := range expected.Filters.AllowedIP { 1428 if !util.IsStringInSlice(IPMask, actual.Filters.AllowedIP) { 1429 return errors.New("allowed IP contents mismatch") 1430 } 1431 } 1432 for _, IPMask := range expected.Filters.DeniedIP { 1433 if !util.IsStringInSlice(IPMask, actual.Filters.DeniedIP) { 1434 return errors.New("denied IP contents mismatch") 1435 } 1436 } 1437 for _, method := range expected.Filters.DeniedLoginMethods { 1438 if !util.IsStringInSlice(method, actual.Filters.DeniedLoginMethods) { 1439 return errors.New("denied login methods contents mismatch") 1440 } 1441 } 1442 for _, protocol := range expected.Filters.DeniedProtocols { 1443 if !util.IsStringInSlice(protocol, actual.Filters.DeniedProtocols) { 1444 return errors.New("denied protocols contents mismatch") 1445 } 1446 } 1447 for _, options := range expected.Filters.WebClient { 1448 if !util.IsStringInSlice(options, actual.Filters.WebClient) { 1449 return errors.New("web client options contents mismatch") 1450 } 1451 } 1452 if expected.Filters.Hooks.ExternalAuthDisabled != actual.Filters.Hooks.ExternalAuthDisabled { 1453 return errors.New("external_auth_disabled hook mismatch") 1454 } 1455 if expected.Filters.Hooks.PreLoginDisabled != actual.Filters.Hooks.PreLoginDisabled { 1456 return errors.New("pre_login_disabled hook mismatch") 1457 } 1458 if expected.Filters.Hooks.CheckPasswordDisabled != actual.Filters.Hooks.CheckPasswordDisabled { 1459 return errors.New("check_password_disabled hook mismatch") 1460 } 1461 if expected.Filters.DisableFsChecks != actual.Filters.DisableFsChecks { 1462 return errors.New("disable_fs_checks mismatch") 1463 } 1464 return nil 1465} 1466 1467func compareUserFilters(expected *dataprovider.User, actual *dataprovider.User) error { 1468 if len(expected.Filters.AllowedIP) != len(actual.Filters.AllowedIP) { 1469 return errors.New("allowed IP mismatch") 1470 } 1471 if len(expected.Filters.DeniedIP) != len(actual.Filters.DeniedIP) { 1472 return errors.New("denied IP mismatch") 1473 } 1474 if len(expected.Filters.DeniedLoginMethods) != len(actual.Filters.DeniedLoginMethods) { 1475 return errors.New("denied login methods mismatch") 1476 } 1477 if len(expected.Filters.DeniedProtocols) != len(actual.Filters.DeniedProtocols) { 1478 return errors.New("denied protocols mismatch") 1479 } 1480 if expected.Filters.MaxUploadFileSize != actual.Filters.MaxUploadFileSize { 1481 return errors.New("max upload file size mismatch") 1482 } 1483 if expected.Filters.TLSUsername != actual.Filters.TLSUsername { 1484 return errors.New("TLSUsername mismatch") 1485 } 1486 if len(expected.Filters.WebClient) != len(actual.Filters.WebClient) { 1487 return errors.New("WebClient filter mismatch") 1488 } 1489 if expected.Filters.AllowAPIKeyAuth != actual.Filters.AllowAPIKeyAuth { 1490 return errors.New("allow_api_key_auth mismatch") 1491 } 1492 if err := compareUserFilterSubStructs(expected, actual); err != nil { 1493 return err 1494 } 1495 return compareUserFilePatternsFilters(expected, actual) 1496} 1497 1498func checkFilterMatch(expected []string, actual []string) bool { 1499 if len(expected) != len(actual) { 1500 return false 1501 } 1502 for _, e := range expected { 1503 if !util.IsStringInSlice(strings.ToLower(e), actual) { 1504 return false 1505 } 1506 } 1507 return true 1508} 1509 1510func compareUserFilePatternsFilters(expected *dataprovider.User, actual *dataprovider.User) error { 1511 if len(expected.Filters.FilePatterns) != len(actual.Filters.FilePatterns) { 1512 return errors.New("file patterns mismatch") 1513 } 1514 for _, f := range expected.Filters.FilePatterns { 1515 found := false 1516 for _, f1 := range actual.Filters.FilePatterns { 1517 if path.Clean(f.Path) == path.Clean(f1.Path) { 1518 if !checkFilterMatch(f.AllowedPatterns, f1.AllowedPatterns) || 1519 !checkFilterMatch(f.DeniedPatterns, f1.DeniedPatterns) { 1520 return errors.New("file patterns contents mismatch") 1521 } 1522 found = true 1523 } 1524 } 1525 if !found { 1526 return errors.New("file patterns contents mismatch") 1527 } 1528 } 1529 return nil 1530} 1531 1532func compareEqualsUserFields(expected *dataprovider.User, actual *dataprovider.User) error { 1533 if expected.Username != actual.Username { 1534 return errors.New("username mismatch") 1535 } 1536 if expected.HomeDir != actual.HomeDir { 1537 return errors.New("home dir mismatch") 1538 } 1539 if expected.UID != actual.UID { 1540 return errors.New("UID mismatch") 1541 } 1542 if expected.GID != actual.GID { 1543 return errors.New("GID mismatch") 1544 } 1545 if expected.MaxSessions != actual.MaxSessions { 1546 return errors.New("MaxSessions mismatch") 1547 } 1548 if expected.QuotaSize != actual.QuotaSize { 1549 return errors.New("QuotaSize mismatch") 1550 } 1551 if expected.QuotaFiles != actual.QuotaFiles { 1552 return errors.New("QuotaFiles mismatch") 1553 } 1554 if len(expected.Permissions) != len(actual.Permissions) { 1555 return errors.New("permissions mismatch") 1556 } 1557 if expected.UploadBandwidth != actual.UploadBandwidth { 1558 return errors.New("UploadBandwidth mismatch") 1559 } 1560 if expected.DownloadBandwidth != actual.DownloadBandwidth { 1561 return errors.New("DownloadBandwidth mismatch") 1562 } 1563 if expected.Status != actual.Status { 1564 return errors.New("status mismatch") 1565 } 1566 if expected.ExpirationDate != actual.ExpirationDate { 1567 return errors.New("ExpirationDate mismatch") 1568 } 1569 if expected.AdditionalInfo != actual.AdditionalInfo { 1570 return errors.New("AdditionalInfo mismatch") 1571 } 1572 if expected.Description != actual.Description { 1573 return errors.New("description mismatch") 1574 } 1575 return nil 1576} 1577 1578func addLimitAndOffsetQueryParams(rawurl string, limit, offset int64) (*url.URL, error) { 1579 url, err := url.Parse(rawurl) 1580 if err != nil { 1581 return nil, err 1582 } 1583 q := url.Query() 1584 if limit > 0 { 1585 q.Add("limit", strconv.FormatInt(limit, 10)) 1586 } 1587 if offset > 0 { 1588 q.Add("offset", strconv.FormatInt(offset, 10)) 1589 } 1590 url.RawQuery = q.Encode() 1591 return url, err 1592} 1593 1594func addModeQueryParam(rawurl, mode string) (*url.URL, error) { 1595 url, err := url.Parse(rawurl) 1596 if err != nil { 1597 return nil, err 1598 } 1599 q := url.Query() 1600 if len(mode) > 0 { 1601 q.Add("mode", mode) 1602 } 1603 url.RawQuery = q.Encode() 1604 return url, err 1605} 1606 1607func addDisconnectQueryParam(rawurl, disconnect string) (*url.URL, error) { 1608 url, err := url.Parse(rawurl) 1609 if err != nil { 1610 return nil, err 1611 } 1612 q := url.Query() 1613 if len(disconnect) > 0 { 1614 q.Add("disconnect", disconnect) 1615 } 1616 url.RawQuery = q.Encode() 1617 return url, err 1618} 1619