1// Copyright 2014 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// +build !go1.7
6
7package context
8
9import (
10	"fmt"
11	"math/rand"
12	"runtime"
13	"strings"
14	"sync"
15	"testing"
16	"time"
17)
18
// otherContext is a Context that's not one of the types defined in context.go.
// This lets us test code paths that differ based on the underlying type of the
// Context.
//
// The wrapped Context is embedded, so every Context method is forwarded to it
// unchanged; only the dynamic type seen by the package differs. Tests build it
// positionally, e.g. otherContext{c}.
type otherContext struct {
	Context
}
25
26func TestBackground(t *testing.T) {
27	c := Background()
28	if c == nil {
29		t.Fatalf("Background returned nil")
30	}
31	select {
32	case x := <-c.Done():
33		t.Errorf("<-c.Done() == %v want nothing (it should block)", x)
34	default:
35	}
36	if got, want := fmt.Sprint(c), "context.Background"; got != want {
37		t.Errorf("Background().String() = %q want %q", got, want)
38	}
39}
40
41func TestTODO(t *testing.T) {
42	c := TODO()
43	if c == nil {
44		t.Fatalf("TODO returned nil")
45	}
46	select {
47	case x := <-c.Done():
48		t.Errorf("<-c.Done() == %v want nothing (it should block)", x)
49	default:
50	}
51	if got, want := fmt.Sprint(c), "context.TODO"; got != want {
52		t.Errorf("TODO().String() = %q want %q", got, want)
53	}
54}
55
56func TestWithCancel(t *testing.T) {
57	c1, cancel := WithCancel(Background())
58
59	if got, want := fmt.Sprint(c1), "context.Background.WithCancel"; got != want {
60		t.Errorf("c1.String() = %q want %q", got, want)
61	}
62
63	o := otherContext{c1}
64	c2, _ := WithCancel(o)
65	contexts := []Context{c1, o, c2}
66
67	for i, c := range contexts {
68		if d := c.Done(); d == nil {
69			t.Errorf("c[%d].Done() == %v want non-nil", i, d)
70		}
71		if e := c.Err(); e != nil {
72			t.Errorf("c[%d].Err() == %v want nil", i, e)
73		}
74
75		select {
76		case x := <-c.Done():
77			t.Errorf("<-c.Done() == %v want nothing (it should block)", x)
78		default:
79		}
80	}
81
82	cancel()
83	time.Sleep(100 * time.Millisecond) // let cancelation propagate
84
85	for i, c := range contexts {
86		select {
87		case <-c.Done():
88		default:
89			t.Errorf("<-c[%d].Done() blocked, but shouldn't have", i)
90		}
91		if e := c.Err(); e != Canceled {
92			t.Errorf("c[%d].Err() == %v want %v", i, e, Canceled)
93		}
94	}
95}
96
97func TestParentFinishesChild(t *testing.T) {
98	// Context tree:
99	// parent -> cancelChild
100	// parent -> valueChild -> timerChild
101	parent, cancel := WithCancel(Background())
102	cancelChild, stop := WithCancel(parent)
103	defer stop()
104	valueChild := WithValue(parent, "key", "value")
105	timerChild, stop := WithTimeout(valueChild, 10000*time.Hour)
106	defer stop()
107
108	select {
109	case x := <-parent.Done():
110		t.Errorf("<-parent.Done() == %v want nothing (it should block)", x)
111	case x := <-cancelChild.Done():
112		t.Errorf("<-cancelChild.Done() == %v want nothing (it should block)", x)
113	case x := <-timerChild.Done():
114		t.Errorf("<-timerChild.Done() == %v want nothing (it should block)", x)
115	case x := <-valueChild.Done():
116		t.Errorf("<-valueChild.Done() == %v want nothing (it should block)", x)
117	default:
118	}
119
120	// The parent's children should contain the two cancelable children.
121	pc := parent.(*cancelCtx)
122	cc := cancelChild.(*cancelCtx)
123	tc := timerChild.(*timerCtx)
124	pc.mu.Lock()
125	if len(pc.children) != 2 || !pc.children[cc] || !pc.children[tc] {
126		t.Errorf("bad linkage: pc.children = %v, want %v and %v",
127			pc.children, cc, tc)
128	}
129	pc.mu.Unlock()
130
131	if p, ok := parentCancelCtx(cc.Context); !ok || p != pc {
132		t.Errorf("bad linkage: parentCancelCtx(cancelChild.Context) = %v, %v want %v, true", p, ok, pc)
133	}
134	if p, ok := parentCancelCtx(tc.Context); !ok || p != pc {
135		t.Errorf("bad linkage: parentCancelCtx(timerChild.Context) = %v, %v want %v, true", p, ok, pc)
136	}
137
138	cancel()
139
140	pc.mu.Lock()
141	if len(pc.children) != 0 {
142		t.Errorf("pc.cancel didn't clear pc.children = %v", pc.children)
143	}
144	pc.mu.Unlock()
145
146	// parent and children should all be finished.
147	check := func(ctx Context, name string) {
148		select {
149		case <-ctx.Done():
150		default:
151			t.Errorf("<-%s.Done() blocked, but shouldn't have", name)
152		}
153		if e := ctx.Err(); e != Canceled {
154			t.Errorf("%s.Err() == %v want %v", name, e, Canceled)
155		}
156	}
157	check(parent, "parent")
158	check(cancelChild, "cancelChild")
159	check(valueChild, "valueChild")
160	check(timerChild, "timerChild")
161
162	// WithCancel should return a canceled context on a canceled parent.
163	precanceledChild := WithValue(parent, "key", "value")
164	select {
165	case <-precanceledChild.Done():
166	default:
167		t.Errorf("<-precanceledChild.Done() blocked, but shouldn't have")
168	}
169	if e := precanceledChild.Err(); e != Canceled {
170		t.Errorf("precanceledChild.Err() == %v want %v", e, Canceled)
171	}
172}
173
174func TestChildFinishesFirst(t *testing.T) {
175	cancelable, stop := WithCancel(Background())
176	defer stop()
177	for _, parent := range []Context{Background(), cancelable} {
178		child, cancel := WithCancel(parent)
179
180		select {
181		case x := <-parent.Done():
182			t.Errorf("<-parent.Done() == %v want nothing (it should block)", x)
183		case x := <-child.Done():
184			t.Errorf("<-child.Done() == %v want nothing (it should block)", x)
185		default:
186		}
187
188		cc := child.(*cancelCtx)
189		pc, pcok := parent.(*cancelCtx) // pcok == false when parent == Background()
190		if p, ok := parentCancelCtx(cc.Context); ok != pcok || (ok && pc != p) {
191			t.Errorf("bad linkage: parentCancelCtx(cc.Context) = %v, %v want %v, %v", p, ok, pc, pcok)
192		}
193
194		if pcok {
195			pc.mu.Lock()
196			if len(pc.children) != 1 || !pc.children[cc] {
197				t.Errorf("bad linkage: pc.children = %v, cc = %v", pc.children, cc)
198			}
199			pc.mu.Unlock()
200		}
201
202		cancel()
203
204		if pcok {
205			pc.mu.Lock()
206			if len(pc.children) != 0 {
207				t.Errorf("child's cancel didn't remove self from pc.children = %v", pc.children)
208			}
209			pc.mu.Unlock()
210		}
211
212		// child should be finished.
213		select {
214		case <-child.Done():
215		default:
216			t.Errorf("<-child.Done() blocked, but shouldn't have")
217		}
218		if e := child.Err(); e != Canceled {
219			t.Errorf("child.Err() == %v want %v", e, Canceled)
220		}
221
222		// parent should not be finished.
223		select {
224		case x := <-parent.Done():
225			t.Errorf("<-parent.Done() == %v want nothing (it should block)", x)
226		default:
227		}
228		if e := parent.Err(); e != nil {
229			t.Errorf("parent.Err() == %v want nil", e)
230		}
231	}
232}
233
234func testDeadline(c Context, wait time.Duration, t *testing.T) {
235	select {
236	case <-time.After(wait):
237		t.Fatalf("context should have timed out")
238	case <-c.Done():
239	}
240	if e := c.Err(); e != DeadlineExceeded {
241		t.Errorf("c.Err() == %v want %v", e, DeadlineExceeded)
242	}
243}
244
245func TestDeadline(t *testing.T) {
246	t.Parallel()
247	const timeUnit = 500 * time.Millisecond
248	c, _ := WithDeadline(Background(), time.Now().Add(1*timeUnit))
249	if got, prefix := fmt.Sprint(c), "context.Background.WithDeadline("; !strings.HasPrefix(got, prefix) {
250		t.Errorf("c.String() = %q want prefix %q", got, prefix)
251	}
252	testDeadline(c, 2*timeUnit, t)
253
254	c, _ = WithDeadline(Background(), time.Now().Add(1*timeUnit))
255	o := otherContext{c}
256	testDeadline(o, 2*timeUnit, t)
257
258	c, _ = WithDeadline(Background(), time.Now().Add(1*timeUnit))
259	o = otherContext{c}
260	c, _ = WithDeadline(o, time.Now().Add(3*timeUnit))
261	testDeadline(c, 2*timeUnit, t)
262}
263
264func TestTimeout(t *testing.T) {
265	t.Parallel()
266	const timeUnit = 500 * time.Millisecond
267	c, _ := WithTimeout(Background(), 1*timeUnit)
268	if got, prefix := fmt.Sprint(c), "context.Background.WithDeadline("; !strings.HasPrefix(got, prefix) {
269		t.Errorf("c.String() = %q want prefix %q", got, prefix)
270	}
271	testDeadline(c, 2*timeUnit, t)
272
273	c, _ = WithTimeout(Background(), 1*timeUnit)
274	o := otherContext{c}
275	testDeadline(o, 2*timeUnit, t)
276
277	c, _ = WithTimeout(Background(), 1*timeUnit)
278	o = otherContext{c}
279	c, _ = WithTimeout(o, 3*timeUnit)
280	testDeadline(c, 2*timeUnit, t)
281}
282
283func TestCanceledTimeout(t *testing.T) {
284	t.Parallel()
285	const timeUnit = 500 * time.Millisecond
286	c, _ := WithTimeout(Background(), 2*timeUnit)
287	o := otherContext{c}
288	c, cancel := WithTimeout(o, 4*timeUnit)
289	cancel()
290	time.Sleep(1 * timeUnit) // let cancelation propagate
291	select {
292	case <-c.Done():
293	default:
294		t.Errorf("<-c.Done() blocked, but shouldn't have")
295	}
296	if e := c.Err(); e != Canceled {
297		t.Errorf("c.Err() == %v want %v", e, Canceled)
298	}
299}
300
// key1 and key2 are distinct named types over int. Interface values of
// different dynamic types never compare equal, so k1 and k2 are different keys
// even though both hold 1.
type (
	key1 int
	key2 int
)

var (
	k1 = key1(1)
	k2 = key2(1) // same int as k1, different type
	k3 = key2(3) // same type as k2, different int
)
307
308func TestValues(t *testing.T) {
309	check := func(c Context, nm, v1, v2, v3 string) {
310		if v, ok := c.Value(k1).(string); ok == (len(v1) == 0) || v != v1 {
311			t.Errorf(`%s.Value(k1).(string) = %q, %t want %q, %t`, nm, v, ok, v1, len(v1) != 0)
312		}
313		if v, ok := c.Value(k2).(string); ok == (len(v2) == 0) || v != v2 {
314			t.Errorf(`%s.Value(k2).(string) = %q, %t want %q, %t`, nm, v, ok, v2, len(v2) != 0)
315		}
316		if v, ok := c.Value(k3).(string); ok == (len(v3) == 0) || v != v3 {
317			t.Errorf(`%s.Value(k3).(string) = %q, %t want %q, %t`, nm, v, ok, v3, len(v3) != 0)
318		}
319	}
320
321	c0 := Background()
322	check(c0, "c0", "", "", "")
323
324	c1 := WithValue(Background(), k1, "c1k1")
325	check(c1, "c1", "c1k1", "", "")
326
327	if got, want := fmt.Sprint(c1), `context.Background.WithValue(1, "c1k1")`; got != want {
328		t.Errorf("c.String() = %q want %q", got, want)
329	}
330
331	c2 := WithValue(c1, k2, "c2k2")
332	check(c2, "c2", "c1k1", "c2k2", "")
333
334	c3 := WithValue(c2, k3, "c3k3")
335	check(c3, "c2", "c1k1", "c2k2", "c3k3")
336
337	c4 := WithValue(c3, k1, nil)
338	check(c4, "c4", "", "c2k2", "c3k3")
339
340	o0 := otherContext{Background()}
341	check(o0, "o0", "", "", "")
342
343	o1 := otherContext{WithValue(Background(), k1, "c1k1")}
344	check(o1, "o1", "c1k1", "", "")
345
346	o2 := WithValue(o1, k2, "o2k2")
347	check(o2, "o2", "c1k1", "o2k2", "")
348
349	o3 := otherContext{c4}
350	check(o3, "o3", "", "c2k2", "c3k3")
351
352	o4 := WithValue(o3, k3, nil)
353	check(o4, "o4", "", "c2k2", "")
354}
355
356func TestAllocs(t *testing.T) {
357	bg := Background()
358	for _, test := range []struct {
359		desc       string
360		f          func()
361		limit      float64
362		gccgoLimit float64
363	}{
364		{
365			desc:       "Background()",
366			f:          func() { Background() },
367			limit:      0,
368			gccgoLimit: 0,
369		},
370		{
371			desc: fmt.Sprintf("WithValue(bg, %v, nil)", k1),
372			f: func() {
373				c := WithValue(bg, k1, nil)
374				c.Value(k1)
375			},
376			limit:      3,
377			gccgoLimit: 3,
378		},
379		{
380			desc: "WithTimeout(bg, 15*time.Millisecond)",
381			f: func() {
382				c, _ := WithTimeout(bg, 15*time.Millisecond)
383				<-c.Done()
384			},
385			limit:      8,
386			gccgoLimit: 16,
387		},
388		{
389			desc: "WithCancel(bg)",
390			f: func() {
391				c, cancel := WithCancel(bg)
392				cancel()
393				<-c.Done()
394			},
395			limit:      5,
396			gccgoLimit: 8,
397		},
398		{
399			desc: "WithTimeout(bg, 100*time.Millisecond)",
400			f: func() {
401				c, cancel := WithTimeout(bg, 100*time.Millisecond)
402				cancel()
403				<-c.Done()
404			},
405			limit:      8,
406			gccgoLimit: 25,
407		},
408	} {
409		limit := test.limit
410		if runtime.Compiler == "gccgo" {
411			// gccgo does not yet do escape analysis.
412			// TODO(iant): Remove this when gccgo does do escape analysis.
413			limit = test.gccgoLimit
414		}
415		if n := testing.AllocsPerRun(100, test.f); n > limit {
416			t.Errorf("%s allocs = %f want %d", test.desc, n, int(limit))
417		}
418	}
419}
420
421func TestSimultaneousCancels(t *testing.T) {
422	root, cancel := WithCancel(Background())
423	m := map[Context]CancelFunc{root: cancel}
424	q := []Context{root}
425	// Create a tree of contexts.
426	for len(q) != 0 && len(m) < 100 {
427		parent := q[0]
428		q = q[1:]
429		for i := 0; i < 4; i++ {
430			ctx, cancel := WithCancel(parent)
431			m[ctx] = cancel
432			q = append(q, ctx)
433		}
434	}
435	// Start all the cancels in a random order.
436	var wg sync.WaitGroup
437	wg.Add(len(m))
438	for _, cancel := range m {
439		go func(cancel CancelFunc) {
440			cancel()
441			wg.Done()
442		}(cancel)
443	}
444	// Wait on all the contexts in a random order.
445	for ctx := range m {
446		select {
447		case <-ctx.Done():
448		case <-time.After(1 * time.Second):
449			buf := make([]byte, 10<<10)
450			n := runtime.Stack(buf, true)
451			t.Fatalf("timed out waiting for <-ctx.Done(); stacks:\n%s", buf[:n])
452		}
453	}
454	// Wait for all the cancel functions to return.
455	done := make(chan struct{})
456	go func() {
457		wg.Wait()
458		close(done)
459	}()
460	select {
461	case <-done:
462	case <-time.After(1 * time.Second):
463		buf := make([]byte, 10<<10)
464		n := runtime.Stack(buf, true)
465		t.Fatalf("timed out waiting for cancel functions; stacks:\n%s", buf[:n])
466	}
467}
468
469func TestInterlockedCancels(t *testing.T) {
470	parent, cancelParent := WithCancel(Background())
471	child, cancelChild := WithCancel(parent)
472	go func() {
473		parent.Done()
474		cancelChild()
475	}()
476	cancelParent()
477	select {
478	case <-child.Done():
479	case <-time.After(1 * time.Second):
480		buf := make([]byte, 10<<10)
481		n := runtime.Stack(buf, true)
482		t.Fatalf("timed out waiting for child.Done(); stacks:\n%s", buf[:n])
483	}
484}
485
486func TestLayersCancel(t *testing.T) {
487	testLayers(t, time.Now().UnixNano(), false)
488}
489
490func TestLayersTimeout(t *testing.T) {
491	testLayers(t, time.Now().UnixNano(), true)
492}
493
494func testLayers(t *testing.T, seed int64, testTimeout bool) {
495	rand.Seed(seed)
496	errorf := func(format string, a ...interface{}) {
497		t.Errorf(fmt.Sprintf("seed=%d: %s", seed, format), a...)
498	}
499	const (
500		timeout   = 200 * time.Millisecond
501		minLayers = 30
502	)
503	type value int
504	var (
505		vals      []*value
506		cancels   []CancelFunc
507		numTimers int
508		ctx       = Background()
509	)
510	for i := 0; i < minLayers || numTimers == 0 || len(cancels) == 0 || len(vals) == 0; i++ {
511		switch rand.Intn(3) {
512		case 0:
513			v := new(value)
514			ctx = WithValue(ctx, v, v)
515			vals = append(vals, v)
516		case 1:
517			var cancel CancelFunc
518			ctx, cancel = WithCancel(ctx)
519			cancels = append(cancels, cancel)
520		case 2:
521			var cancel CancelFunc
522			ctx, cancel = WithTimeout(ctx, timeout)
523			cancels = append(cancels, cancel)
524			numTimers++
525		}
526	}
527	checkValues := func(when string) {
528		for _, key := range vals {
529			if val := ctx.Value(key).(*value); key != val {
530				errorf("%s: ctx.Value(%p) = %p want %p", when, key, val, key)
531			}
532		}
533	}
534	select {
535	case <-ctx.Done():
536		errorf("ctx should not be canceled yet")
537	default:
538	}
539	if s, prefix := fmt.Sprint(ctx), "context.Background."; !strings.HasPrefix(s, prefix) {
540		t.Errorf("ctx.String() = %q want prefix %q", s, prefix)
541	}
542	t.Log(ctx)
543	checkValues("before cancel")
544	if testTimeout {
545		select {
546		case <-ctx.Done():
547		case <-time.After(timeout + 100*time.Millisecond):
548			errorf("ctx should have timed out")
549		}
550		checkValues("after timeout")
551	} else {
552		cancel := cancels[rand.Intn(len(cancels))]
553		cancel()
554		select {
555		case <-ctx.Done():
556		default:
557			errorf("ctx should be canceled")
558		}
559		checkValues("after cancel")
560	}
561}
562
563func TestCancelRemoves(t *testing.T) {
564	checkChildren := func(when string, ctx Context, want int) {
565		if got := len(ctx.(*cancelCtx).children); got != want {
566			t.Errorf("%s: context has %d children, want %d", when, got, want)
567		}
568	}
569
570	ctx, _ := WithCancel(Background())
571	checkChildren("after creation", ctx, 0)
572	_, cancel := WithCancel(ctx)
573	checkChildren("with WithCancel child ", ctx, 1)
574	cancel()
575	checkChildren("after cancelling WithCancel child", ctx, 0)
576
577	ctx, _ = WithCancel(Background())
578	checkChildren("after creation", ctx, 0)
579	_, cancel = WithTimeout(ctx, 60*time.Minute)
580	checkChildren("with WithTimeout child ", ctx, 1)
581	cancel()
582	checkChildren("after cancelling WithTimeout child", ctx, 0)
583}
584