1package testhelper 2 3import ( 4 "context" 5 "crypto/ecdsa" 6 "crypto/elliptic" 7 "crypto/rand" 8 "crypto/x509" 9 "encoding/base64" 10 "encoding/json" 11 "encoding/pem" 12 "fmt" 13 "io" 14 "math/big" 15 "net" 16 "os" 17 "os/exec" 18 "path/filepath" 19 "strings" 20 "syscall" 21 "testing" 22 "time" 23 24 "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" 25 log "github.com/sirupsen/logrus" 26 "github.com/stretchr/testify/assert" 27 "github.com/stretchr/testify/require" 28 "gitlab.com/gitlab-org/gitaly/v14/internal/command" 29 "gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/config" 30 "gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/storage" 31 "gitlab.com/gitlab-org/gitaly/v14/internal/helper/text" 32 "gitlab.com/gitlab-org/gitaly/v14/internal/metadata/featureflag" 33 "go.uber.org/goleak" 34 "google.golang.org/grpc/metadata" 35) 36 37const ( 38 // RepositoryAuthToken is the default token used to authenticate 39 // against other Gitaly servers. It is inject as part of the 40 // GitalyServers metadata. 41 RepositoryAuthToken = "the-secret-token" 42 // DefaultStorageName is the default name of the Gitaly storage. 43 DefaultStorageName = "default" 44) 45 46// SkipWithPraefect skips the test if it is being executed with Praefect in front 47// of the Gitaly. 48func SkipWithPraefect(t testing.TB, reason string) { 49 if os.Getenv("GITALY_TEST_WITH_PRAEFECT") == "YesPlease" { 50 t.Skipf(reason) 51 } 52} 53 54// MustReadFile returns the content of a file or fails at once. 55func MustReadFile(t testing.TB, filename string) []byte { 56 content, err := os.ReadFile(filename) 57 if err != nil { 58 t.Fatal(err) 59 } 60 61 return content 62} 63 64// GitlabTestStoragePath returns the storage path to the gitlab-test repo. 65func GitlabTestStoragePath() string { 66 if testDirectory == "" { 67 panic("you must call testhelper.Configure() before GitlabTestStoragePath()") 68 } 69 return filepath.Join(testDirectory, "storage") 70} 71 72// GitalyServersMetadataFromCfg returns a metadata pair for gitaly-servers to be used in 73// inter-gitaly operations. 74func GitalyServersMetadataFromCfg(t testing.TB, cfg config.Cfg) metadata.MD { 75 gitalyServers := storage.GitalyServers{} 76storages: 77 for _, s := range cfg.Storages { 78 // It picks up the first address configured: TLS, TCP or UNIX. 79 for _, addr := range []string{cfg.TLSListenAddr, cfg.ListenAddr, cfg.SocketPath} { 80 if addr != "" { 81 gitalyServers[s.Name] = storage.ServerInfo{ 82 Address: addr, 83 Token: cfg.Auth.Token, 84 } 85 continue storages 86 } 87 } 88 require.FailNow(t, "no address found on the config") 89 } 90 91 gitalyServersJSON, err := json.Marshal(gitalyServers) 92 if err != nil { 93 t.Fatal(err) 94 } 95 96 return metadata.Pairs("gitaly-servers", base64.StdEncoding.EncodeToString(gitalyServersJSON)) 97} 98 99// MustRunCommand runs a command with an optional standard input and returns the standard output, or fails. 100func MustRunCommand(t testing.TB, stdin io.Reader, name string, args ...string) []byte { 101 t.Helper() 102 103 if filepath.Base(name) == "git" { 104 require.Fail(t, "Please use gittest.Exec or gittest.ExecStream to run git commands.") 105 } 106 107 cmd := exec.Command(name, args...) 108 if stdin != nil { 109 cmd.Stdin = stdin 110 } 111 112 output, err := cmd.Output() 113 if err != nil { 114 stderr := err.(*exec.ExitError).Stderr 115 require.NoError(t, err, "%s %s: %s", name, args, stderr) 116 } 117 118 return output 119} 120 121// MustClose calls Close() on the Closer and fails the test in case it returns 122// an error. This function is useful when closing via `defer`, as a simple 123// `defer require.NoError(t, closer.Close())` would cause `closer.Close()` to 124// be executed early already. 125func MustClose(t testing.TB, closer io.Closer) { 126 require.NoError(t, closer.Close()) 127} 128 129// CopyFile copies a file at the path src to a file at the path dst 130func CopyFile(t testing.TB, src, dst string) { 131 fsrc, err := os.Open(src) 132 require.NoError(t, err) 133 defer MustClose(t, fsrc) 134 135 fdst, err := os.Create(dst) 136 require.NoError(t, err) 137 defer MustClose(t, fdst) 138 139 _, err = io.Copy(fdst, fsrc) 140 require.NoError(t, err) 141} 142 143// GetTemporaryGitalySocketFileName will return a unique, useable socket file name 144func GetTemporaryGitalySocketFileName(t testing.TB) string { 145 require.NotEmpty(t, testDirectory, "you must call testhelper.Configure() before GetTemporaryGitalySocketFileName()") 146 147 tmpfile, err := os.CreateTemp(testDirectory, "gitaly.socket.") 148 require.NoError(t, err) 149 150 name := tmpfile.Name() 151 require.NoError(t, tmpfile.Close()) 152 require.NoError(t, os.Remove(name)) 153 154 return name 155} 156 157// GetLocalhostListener listens on the next available TCP port and returns 158// the listener and the localhost address (host:port) string. 159func GetLocalhostListener(t testing.TB) (net.Listener, string) { 160 l, err := net.Listen("tcp", "localhost:0") 161 require.NoError(t, err) 162 163 addr := fmt.Sprintf("localhost:%d", l.Addr().(*net.TCPAddr).Port) 164 165 return l, addr 166} 167 168// MustHaveNoGoroutines panics if it finds any Goroutines running. 169func MustHaveNoGoroutines() { 170 if err := goleak.Find( 171 // opencensus has a "defaultWorker" which is started by the package's 172 // `init()` function. There is no way to stop this worker, so it will leak 173 // whenever we import the package. 174 goleak.IgnoreTopFunction("go.opencensus.io/stats/view.(*worker).start"), 175 ); err != nil { 176 panic(fmt.Errorf("goroutines running: %w", err)) 177 } 178} 179 180// MustHaveNoChildProcess panics if it finds a running or finished child 181// process. It waits for 2 seconds for processes to be cleaned up by other 182// goroutines. 183func MustHaveNoChildProcess() { 184 waitDone := make(chan struct{}) 185 go func() { 186 command.WaitAllDone() 187 close(waitDone) 188 }() 189 190 select { 191 case <-waitDone: 192 case <-time.After(2 * time.Second): 193 } 194 195 mustFindNoFinishedChildProcess() 196 mustFindNoRunningChildProcess() 197} 198 199func mustFindNoFinishedChildProcess() { 200 // Wait4(pid int, wstatus *WaitStatus, options int, rusage *Rusage) (wpid int, err error) 201 // 202 // We use pid -1 to wait for any child. We don't care about wstatus or 203 // rusage. Use WNOHANG to return immediately if there is no child waiting 204 // to be reaped. 205 wpid, err := syscall.Wait4(-1, nil, syscall.WNOHANG, nil) 206 if err == nil && wpid > 0 { 207 panic(fmt.Errorf("wait4 found child process %d", wpid)) 208 } 209} 210 211func mustFindNoRunningChildProcess() { 212 pgrep := exec.Command("pgrep", "-P", fmt.Sprintf("%d", os.Getpid())) 213 desc := fmt.Sprintf("%q", strings.Join(pgrep.Args, " ")) 214 215 out, err := pgrep.Output() 216 if err == nil { 217 pidsComma := strings.Replace(text.ChompBytes(out), "\n", ",", -1) 218 psOut, _ := exec.Command("ps", "-o", "pid,args", "-p", pidsComma).Output() 219 panic(fmt.Errorf("found running child processes %s:\n%s", pidsComma, psOut)) 220 } 221 222 if status, ok := command.ExitStatus(err); ok && status == 1 { 223 // Exit status 1 means no processes were found 224 return 225 } 226 227 panic(fmt.Errorf("%s: %w", desc, err)) 228} 229 230// ContextOpt returns a new context instance with the new additions to it. 231type ContextOpt func(context.Context) (context.Context, func()) 232 233// ContextWithTimeout allows to set provided timeout to the context. 234func ContextWithTimeout(duration time.Duration) ContextOpt { 235 return func(ctx context.Context) (context.Context, func()) { 236 return context.WithTimeout(ctx, duration) 237 } 238} 239 240// ContextWithLogger allows to inject provided logger into the context. 241func ContextWithLogger(logger *log.Entry) ContextOpt { 242 return func(ctx context.Context) (context.Context, func()) { 243 return ctxlogrus.ToContext(ctx, logger), func() {} 244 } 245} 246 247// Context returns a cancellable context. 248func Context(opts ...ContextOpt) (context.Context, func()) { 249 ctx, cancel := context.WithCancel(context.Background()) 250 for _, ff := range featureflag.All { 251 ctx = featureflag.IncomingCtxWithFeatureFlag(ctx, ff) 252 ctx = featureflag.OutgoingCtxWithFeatureFlags(ctx, ff) 253 } 254 255 cancels := make([]func(), len(opts)+1) 256 cancels[0] = cancel 257 for i, opt := range opts { 258 ctx, cancel = opt(ctx) 259 cancels[i+1] = cancel 260 } 261 262 return ctx, func() { 263 for i := len(cancels) - 1; i >= 0; i-- { 264 cancels[i]() 265 } 266 } 267} 268 269// TempDir is a wrapper around os.MkdirTemp that provides a cleanup function. 270func TempDir(t testing.TB) string { 271 if testDirectory == "" { 272 panic("you must call testhelper.Configure() before TempDir()") 273 } 274 275 tmpDir, err := os.MkdirTemp(testDirectory, "") 276 require.NoError(t, err) 277 t.Cleanup(func() { 278 require.NoError(t, os.RemoveAll(tmpDir)) 279 }) 280 281 return tmpDir 282} 283 284// Cleanup functions should be called in a defer statement 285// immediately after they are returned from a test helper 286type Cleanup func() 287 288// WriteExecutable ensures that the parent directory exists, and writes an executable with provided content 289func WriteExecutable(t testing.TB, path string, content []byte) { 290 dir := filepath.Dir(path) 291 292 require.NoError(t, os.MkdirAll(dir, 0o755)) 293 require.NoError(t, os.WriteFile(path, content, 0o755)) 294 295 t.Cleanup(func() { 296 assert.NoError(t, os.RemoveAll(dir)) 297 }) 298} 299 300// ModifyEnvironment will change an environment variable and return a func suitable 301// for `defer` to change the value back. 302func ModifyEnvironment(t testing.TB, key string, value string) func() { 303 t.Helper() 304 305 oldValue, hasOldValue := os.LookupEnv(key) 306 require.NoError(t, os.Setenv(key, value)) 307 return func() { 308 if hasOldValue { 309 require.NoError(t, os.Setenv(key, oldValue)) 310 } else { 311 require.NoError(t, os.Unsetenv(key)) 312 } 313 } 314} 315 316// GenerateCerts creates a certificate that can be used to establish TLS protected TCP connection. 317// It returns paths to the file with the certificate and its private key. 318func GenerateCerts(t *testing.T) (string, string) { 319 t.Helper() 320 321 rootCA := &x509.Certificate{ 322 SerialNumber: big.NewInt(1), 323 NotBefore: time.Now(), 324 NotAfter: time.Now().AddDate(0, 0, 1), 325 BasicConstraintsValid: true, 326 IsCA: true, 327 IPAddresses: []net.IP{net.ParseIP("0.0.0.0"), net.ParseIP("127.0.0.1"), net.ParseIP("::1"), net.ParseIP("::")}, 328 DNSNames: []string{"localhost"}, 329 KeyUsage: x509.KeyUsageCertSign, 330 } 331 332 caKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) 333 require.NoError(t, err) 334 335 caCert, err := x509.CreateCertificate(rand.Reader, rootCA, rootCA, &caKey.PublicKey, caKey) 336 require.NoError(t, err) 337 338 entityKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) 339 require.NoError(t, err) 340 341 entityX509 := &x509.Certificate{ 342 SerialNumber: big.NewInt(2), 343 } 344 345 entityCert, err := x509.CreateCertificate(rand.Reader, rootCA, entityX509, &entityKey.PublicKey, caKey) 346 require.NoError(t, err) 347 348 certFile, err := os.CreateTemp(testDirectory, "") 349 require.NoError(t, err) 350 defer MustClose(t, certFile) 351 t.Cleanup(func() { 352 require.NoError(t, os.Remove(certFile.Name())) 353 }) 354 355 // create chained PEM file with CA and entity cert 356 for _, cert := range [][]byte{entityCert, caCert} { 357 require.NoError(t, 358 pem.Encode(certFile, &pem.Block{ 359 Type: "CERTIFICATE", 360 Bytes: cert, 361 }), 362 ) 363 } 364 365 keyFile, err := os.CreateTemp(testDirectory, "") 366 require.NoError(t, err) 367 defer MustClose(t, keyFile) 368 t.Cleanup(func() { 369 require.NoError(t, os.Remove(keyFile.Name())) 370 }) 371 372 entityKeyBytes, err := x509.MarshalECPrivateKey(entityKey) 373 require.NoError(t, err) 374 375 require.NoError(t, 376 pem.Encode(keyFile, &pem.Block{ 377 Type: "ECDSA PRIVATE KEY", 378 Bytes: entityKeyBytes, 379 }), 380 ) 381 382 return certFile.Name(), keyFile.Name() 383} 384