1package testhelper
2
3import (
4	"context"
5	"crypto/ecdsa"
6	"crypto/elliptic"
7	"crypto/rand"
8	"crypto/x509"
9	"encoding/base64"
10	"encoding/json"
11	"encoding/pem"
12	"fmt"
13	"io"
14	"math/big"
15	"net"
16	"os"
17	"os/exec"
18	"path/filepath"
19	"strings"
20	"syscall"
21	"testing"
22	"time"
23
24	"github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus"
25	log "github.com/sirupsen/logrus"
26	"github.com/stretchr/testify/assert"
27	"github.com/stretchr/testify/require"
28	"gitlab.com/gitlab-org/gitaly/v14/internal/command"
29	"gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/config"
30	"gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/storage"
31	"gitlab.com/gitlab-org/gitaly/v14/internal/helper/text"
32	"gitlab.com/gitlab-org/gitaly/v14/internal/metadata/featureflag"
33	"go.uber.org/goleak"
34	"google.golang.org/grpc/metadata"
35)
36
// RepositoryAuthToken is the default token used to authenticate against
// other Gitaly servers. It is injected as part of the GitalyServers metadata.
const RepositoryAuthToken = "the-secret-token"

// DefaultStorageName is the default name of the Gitaly storage.
const DefaultStorageName = "default"
45
// SkipWithPraefect skips the test if it is being executed with Praefect in front
// of Gitaly, which is signalled by the GITALY_TEST_WITH_PRAEFECT environment
// variable being set to "YesPlease". reason is reported verbatim as the skip
// message.
func SkipWithPraefect(t testing.TB, reason string) {
	t.Helper()

	if os.Getenv("GITALY_TEST_WITH_PRAEFECT") == "YesPlease" {
		// Skip rather than Skipf: reason is free-form text, not a format
		// string, so a literal '%' in it must not be interpreted as a verb.
		t.Skip(reason)
	}
}
53
// MustReadFile returns the content of the file at filename, failing the test
// immediately if the file cannot be read.
func MustReadFile(t testing.TB, filename string) []byte {
	t.Helper()

	content, err := os.ReadFile(filename)
	if err != nil {
		t.Fatal(err)
	}

	return content
}
63
64// GitlabTestStoragePath returns the storage path to the gitlab-test repo.
65func GitlabTestStoragePath() string {
66	if testDirectory == "" {
67		panic("you must call testhelper.Configure() before GitlabTestStoragePath()")
68	}
69	return filepath.Join(testDirectory, "storage")
70}
71
72// GitalyServersMetadataFromCfg returns a metadata pair for gitaly-servers to be used in
73// inter-gitaly operations.
74func GitalyServersMetadataFromCfg(t testing.TB, cfg config.Cfg) metadata.MD {
75	gitalyServers := storage.GitalyServers{}
76storages:
77	for _, s := range cfg.Storages {
78		// It picks up the first address configured: TLS, TCP or UNIX.
79		for _, addr := range []string{cfg.TLSListenAddr, cfg.ListenAddr, cfg.SocketPath} {
80			if addr != "" {
81				gitalyServers[s.Name] = storage.ServerInfo{
82					Address: addr,
83					Token:   cfg.Auth.Token,
84				}
85				continue storages
86			}
87		}
88		require.FailNow(t, "no address found on the config")
89	}
90
91	gitalyServersJSON, err := json.Marshal(gitalyServers)
92	if err != nil {
93		t.Fatal(err)
94	}
95
96	return metadata.Pairs("gitaly-servers", base64.StdEncoding.EncodeToString(gitalyServersJSON))
97}
98
99// MustRunCommand runs a command with an optional standard input and returns the standard output, or fails.
100func MustRunCommand(t testing.TB, stdin io.Reader, name string, args ...string) []byte {
101	t.Helper()
102
103	if filepath.Base(name) == "git" {
104		require.Fail(t, "Please use gittest.Exec or gittest.ExecStream to run git commands.")
105	}
106
107	cmd := exec.Command(name, args...)
108	if stdin != nil {
109		cmd.Stdin = stdin
110	}
111
112	output, err := cmd.Output()
113	if err != nil {
114		stderr := err.(*exec.ExitError).Stderr
115		require.NoError(t, err, "%s %s: %s", name, args, stderr)
116	}
117
118	return output
119}
120
121// MustClose calls Close() on the Closer and fails the test in case it returns
122// an error. This function is useful when closing via `defer`, as a simple
123// `defer require.NoError(t, closer.Close())` would cause `closer.Close()` to
124// be executed early already.
125func MustClose(t testing.TB, closer io.Closer) {
126	require.NoError(t, closer.Close())
127}
128
129// CopyFile copies a file at the path src to a file at the path dst
130func CopyFile(t testing.TB, src, dst string) {
131	fsrc, err := os.Open(src)
132	require.NoError(t, err)
133	defer MustClose(t, fsrc)
134
135	fdst, err := os.Create(dst)
136	require.NoError(t, err)
137	defer MustClose(t, fdst)
138
139	_, err = io.Copy(fdst, fsrc)
140	require.NoError(t, err)
141}
142
143// GetTemporaryGitalySocketFileName will return a unique, useable socket file name
144func GetTemporaryGitalySocketFileName(t testing.TB) string {
145	require.NotEmpty(t, testDirectory, "you must call testhelper.Configure() before GetTemporaryGitalySocketFileName()")
146
147	tmpfile, err := os.CreateTemp(testDirectory, "gitaly.socket.")
148	require.NoError(t, err)
149
150	name := tmpfile.Name()
151	require.NoError(t, tmpfile.Close())
152	require.NoError(t, os.Remove(name))
153
154	return name
155}
156
157// GetLocalhostListener listens on the next available TCP port and returns
158// the listener and the localhost address (host:port) string.
159func GetLocalhostListener(t testing.TB) (net.Listener, string) {
160	l, err := net.Listen("tcp", "localhost:0")
161	require.NoError(t, err)
162
163	addr := fmt.Sprintf("localhost:%d", l.Addr().(*net.TCPAddr).Port)
164
165	return l, addr
166}
167
168// MustHaveNoGoroutines panics if it finds any Goroutines running.
169func MustHaveNoGoroutines() {
170	if err := goleak.Find(
171		// opencensus has a "defaultWorker" which is started by the package's
172		// `init()` function. There is no way to stop this worker, so it will leak
173		// whenever we import the package.
174		goleak.IgnoreTopFunction("go.opencensus.io/stats/view.(*worker).start"),
175	); err != nil {
176		panic(fmt.Errorf("goroutines running: %w", err))
177	}
178}
179
180// MustHaveNoChildProcess panics if it finds a running or finished child
181// process. It waits for 2 seconds for processes to be cleaned up by other
182// goroutines.
183func MustHaveNoChildProcess() {
184	waitDone := make(chan struct{})
185	go func() {
186		command.WaitAllDone()
187		close(waitDone)
188	}()
189
190	select {
191	case <-waitDone:
192	case <-time.After(2 * time.Second):
193	}
194
195	mustFindNoFinishedChildProcess()
196	mustFindNoRunningChildProcess()
197}
198
// mustFindNoFinishedChildProcess panics if a child of this process has exited
// but has not been waited for yet. Note that finding such a child also reaps
// it, because wait4 is called without WNOWAIT.
func mustFindNoFinishedChildProcess() {
	// A pid of -1 asks for any child. WNOHANG makes the call return
	// immediately when no child is waiting to be reaped. Neither the exit
	// status nor the resource usage is of interest, hence the nil arguments.
	pid, err := syscall.Wait4(-1, nil, syscall.WNOHANG, nil)
	if err != nil || pid <= 0 {
		return
	}

	panic(fmt.Errorf("wait4 found child process %d", pid))
}
210
211func mustFindNoRunningChildProcess() {
212	pgrep := exec.Command("pgrep", "-P", fmt.Sprintf("%d", os.Getpid()))
213	desc := fmt.Sprintf("%q", strings.Join(pgrep.Args, " "))
214
215	out, err := pgrep.Output()
216	if err == nil {
217		pidsComma := strings.Replace(text.ChompBytes(out), "\n", ",", -1)
218		psOut, _ := exec.Command("ps", "-o", "pid,args", "-p", pidsComma).Output()
219		panic(fmt.Errorf("found running child processes %s:\n%s", pidsComma, psOut))
220	}
221
222	if status, ok := command.ExitStatus(err); ok && status == 1 {
223		// Exit status 1 means no processes were found
224		return
225	}
226
227	panic(fmt.Errorf("%s: %w", desc, err))
228}
229
// ContextOpt is an option accepted by Context. It derives a new context from
// ctx and returns it together with a cleanup function that Context's own
// cancel function invokes.
type ContextOpt func(ctx context.Context) (context.Context, func())
232
233// ContextWithTimeout allows to set provided timeout to the context.
234func ContextWithTimeout(duration time.Duration) ContextOpt {
235	return func(ctx context.Context) (context.Context, func()) {
236		return context.WithTimeout(ctx, duration)
237	}
238}
239
240// ContextWithLogger allows to inject provided logger into the context.
241func ContextWithLogger(logger *log.Entry) ContextOpt {
242	return func(ctx context.Context) (context.Context, func()) {
243		return ctxlogrus.ToContext(ctx, logger), func() {}
244	}
245}
246
247// Context returns a cancellable context.
248func Context(opts ...ContextOpt) (context.Context, func()) {
249	ctx, cancel := context.WithCancel(context.Background())
250	for _, ff := range featureflag.All {
251		ctx = featureflag.IncomingCtxWithFeatureFlag(ctx, ff)
252		ctx = featureflag.OutgoingCtxWithFeatureFlags(ctx, ff)
253	}
254
255	cancels := make([]func(), len(opts)+1)
256	cancels[0] = cancel
257	for i, opt := range opts {
258		ctx, cancel = opt(ctx)
259		cancels[i+1] = cancel
260	}
261
262	return ctx, func() {
263		for i := len(cancels) - 1; i >= 0; i-- {
264			cancels[i]()
265		}
266	}
267}
268
269// TempDir is a wrapper around os.MkdirTemp that provides a cleanup function.
270func TempDir(t testing.TB) string {
271	if testDirectory == "" {
272		panic("you must call testhelper.Configure() before TempDir()")
273	}
274
275	tmpDir, err := os.MkdirTemp(testDirectory, "")
276	require.NoError(t, err)
277	t.Cleanup(func() {
278		require.NoError(t, os.RemoveAll(tmpDir))
279	})
280
281	return tmpDir
282}
283
// Cleanup is the type of the teardown function returned by some test helpers.
// Cleanup functions should be called in a defer statement
// immediately after they are returned from a test helper.
type Cleanup func()
287
288// WriteExecutable ensures that the parent directory exists, and writes an executable with provided content
289func WriteExecutable(t testing.TB, path string, content []byte) {
290	dir := filepath.Dir(path)
291
292	require.NoError(t, os.MkdirAll(dir, 0o755))
293	require.NoError(t, os.WriteFile(path, content, 0o755))
294
295	t.Cleanup(func() {
296		assert.NoError(t, os.RemoveAll(dir))
297	})
298}
299
300// ModifyEnvironment will change an environment variable and return a func suitable
301// for `defer` to change the value back.
302func ModifyEnvironment(t testing.TB, key string, value string) func() {
303	t.Helper()
304
305	oldValue, hasOldValue := os.LookupEnv(key)
306	require.NoError(t, os.Setenv(key, value))
307	return func() {
308		if hasOldValue {
309			require.NoError(t, os.Setenv(key, oldValue))
310		} else {
311			require.NoError(t, os.Unsetenv(key))
312		}
313	}
314}
315
316// GenerateCerts creates a certificate that can be used to establish TLS protected TCP connection.
317// It returns paths to the file with the certificate and its private key.
318func GenerateCerts(t *testing.T) (string, string) {
319	t.Helper()
320
321	rootCA := &x509.Certificate{
322		SerialNumber:          big.NewInt(1),
323		NotBefore:             time.Now(),
324		NotAfter:              time.Now().AddDate(0, 0, 1),
325		BasicConstraintsValid: true,
326		IsCA:                  true,
327		IPAddresses:           []net.IP{net.ParseIP("0.0.0.0"), net.ParseIP("127.0.0.1"), net.ParseIP("::1"), net.ParseIP("::")},
328		DNSNames:              []string{"localhost"},
329		KeyUsage:              x509.KeyUsageCertSign,
330	}
331
332	caKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
333	require.NoError(t, err)
334
335	caCert, err := x509.CreateCertificate(rand.Reader, rootCA, rootCA, &caKey.PublicKey, caKey)
336	require.NoError(t, err)
337
338	entityKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
339	require.NoError(t, err)
340
341	entityX509 := &x509.Certificate{
342		SerialNumber: big.NewInt(2),
343	}
344
345	entityCert, err := x509.CreateCertificate(rand.Reader, rootCA, entityX509, &entityKey.PublicKey, caKey)
346	require.NoError(t, err)
347
348	certFile, err := os.CreateTemp(testDirectory, "")
349	require.NoError(t, err)
350	defer MustClose(t, certFile)
351	t.Cleanup(func() {
352		require.NoError(t, os.Remove(certFile.Name()))
353	})
354
355	// create chained PEM file with CA and entity cert
356	for _, cert := range [][]byte{entityCert, caCert} {
357		require.NoError(t,
358			pem.Encode(certFile, &pem.Block{
359				Type:  "CERTIFICATE",
360				Bytes: cert,
361			}),
362		)
363	}
364
365	keyFile, err := os.CreateTemp(testDirectory, "")
366	require.NoError(t, err)
367	defer MustClose(t, keyFile)
368	t.Cleanup(func() {
369		require.NoError(t, os.Remove(keyFile.Name()))
370	})
371
372	entityKeyBytes, err := x509.MarshalECPrivateKey(entityKey)
373	require.NoError(t, err)
374
375	require.NoError(t,
376		pem.Encode(keyFile, &pem.Block{
377			Type:  "ECDSA PRIVATE KEY",
378			Bytes: entityKeyBytes,
379		}),
380	)
381
382	return certFile.Name(), keyFile.Name()
383}
384