Source file src/context/x_test.go

     1  // Copyright 2016 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package context_test
     6  
     7  import (
     8  	. "context"
     9  	"errors"
    10  	"fmt"
    11  	"math/rand"
    12  	"runtime"
    13  	"strings"
    14  	"sync"
    15  	"testing"
    16  	"time"
    17  )
    18  
    19  // Each XTestFoo in context_test.go must be called from a TestFoo here to run.
    20  func TestParentFinishesChild(t *testing.T) {
    21  	XTestParentFinishesChild(t) // uses unexported context types
    22  }
    23  func TestChildFinishesFirst(t *testing.T) {
    24  	XTestChildFinishesFirst(t) // uses unexported context types
    25  }
    26  func TestCancelRemoves(t *testing.T) {
    27  	XTestCancelRemoves(t) // uses unexported context types
    28  }
    29  func TestCustomContextGoroutines(t *testing.T) {
    30  	XTestCustomContextGoroutines(t) // reads the context.goroutines counter
    31  }
    32  
    33  // The following are regular tests in package context_test.
    34  
    35  // otherContext is a Context that's not one of the types defined in context.go.
    36  // This lets us test code paths that differ based on the underlying type of the
    37  // Context.
    38  type otherContext struct {
    39  	Context
    40  }
    41  
    42  const (
    43  	shortDuration    = 1 * time.Millisecond // a reasonable duration to block in a test
    44  	veryLongDuration = 1000 * time.Hour     // an arbitrary upper bound on the test's running time
    45  )
    46  
    47  // quiescent returns an arbitrary duration by which the program should have
    48  // completed any remaining work and reached a steady (idle) state.
    49  func quiescent(t *testing.T) time.Duration {
    50  	deadline, ok := t.Deadline()
    51  	if !ok {
    52  		return 5 * time.Second
    53  	}
    54  
    55  	const arbitraryCleanupMargin = 1 * time.Second
    56  	return time.Until(deadline) - arbitraryCleanupMargin
    57  }
    58  func TestBackground(t *testing.T) {
    59  	c := Background()
    60  	if c == nil {
    61  		t.Fatalf("Background returned nil")
    62  	}
    63  	select {
    64  	case x := <-c.Done():
    65  		t.Errorf("<-c.Done() == %v want nothing (it should block)", x)
    66  	default:
    67  	}
    68  	if got, want := fmt.Sprint(c), "context.Background"; got != want {
    69  		t.Errorf("Background().String() = %q want %q", got, want)
    70  	}
    71  }
    72  
    73  func TestTODO(t *testing.T) {
    74  	c := TODO()
    75  	if c == nil {
    76  		t.Fatalf("TODO returned nil")
    77  	}
    78  	select {
    79  	case x := <-c.Done():
    80  		t.Errorf("<-c.Done() == %v want nothing (it should block)", x)
    81  	default:
    82  	}
    83  	if got, want := fmt.Sprint(c), "context.TODO"; got != want {
    84  		t.Errorf("TODO().String() = %q want %q", got, want)
    85  	}
    86  }
    87  
    88  func TestWithCancel(t *testing.T) {
    89  	c1, cancel := WithCancel(Background())
    90  
    91  	if got, want := fmt.Sprint(c1), "context.Background.WithCancel"; got != want {
    92  		t.Errorf("c1.String() = %q want %q", got, want)
    93  	}
    94  
    95  	o := otherContext{c1}
    96  	c2, _ := WithCancel(o)
    97  	contexts := []Context{c1, o, c2}
    98  
    99  	for i, c := range contexts {
   100  		if d := c.Done(); d == nil {
   101  			t.Errorf("c[%d].Done() == %v want non-nil", i, d)
   102  		}
   103  		if e := c.Err(); e != nil {
   104  			t.Errorf("c[%d].Err() == %v want nil", i, e)
   105  		}
   106  
   107  		select {
   108  		case x := <-c.Done():
   109  			t.Errorf("<-c.Done() == %v want nothing (it should block)", x)
   110  		default:
   111  		}
   112  	}
   113  
   114  	cancel() // Should propagate synchronously.
   115  	for i, c := range contexts {
   116  		select {
   117  		case <-c.Done():
   118  		default:
   119  			t.Errorf("<-c[%d].Done() blocked, but shouldn't have", i)
   120  		}
   121  		if e := c.Err(); e != Canceled {
   122  			t.Errorf("c[%d].Err() == %v want %v", i, e, Canceled)
   123  		}
   124  	}
   125  }
   126  
   127  func testDeadline(c Context, name string, t *testing.T) {
   128  	t.Helper()
   129  	d := quiescent(t)
   130  	timer := time.NewTimer(d)
   131  	defer timer.Stop()
   132  	select {
   133  	case <-timer.C:
   134  		t.Fatalf("%s: context not timed out after %v", name, d)
   135  	case <-c.Done():
   136  	}
   137  	if e := c.Err(); e != DeadlineExceeded {
   138  		t.Errorf("%s: c.Err() == %v; want %v", name, e, DeadlineExceeded)
   139  	}
   140  }
   141  
   142  func TestDeadline(t *testing.T) {
   143  	t.Parallel()
   144  
   145  	c, _ := WithDeadline(Background(), time.Now().Add(shortDuration))
   146  	if got, prefix := fmt.Sprint(c), "context.Background.WithDeadline("; !strings.HasPrefix(got, prefix) {
   147  		t.Errorf("c.String() = %q want prefix %q", got, prefix)
   148  	}
   149  	testDeadline(c, "WithDeadline", t)
   150  
   151  	c, _ = WithDeadline(Background(), time.Now().Add(shortDuration))
   152  	o := otherContext{c}
   153  	testDeadline(o, "WithDeadline+otherContext", t)
   154  
   155  	c, _ = WithDeadline(Background(), time.Now().Add(shortDuration))
   156  	o = otherContext{c}
   157  	c, _ = WithDeadline(o, time.Now().Add(veryLongDuration))
   158  	testDeadline(c, "WithDeadline+otherContext+WithDeadline", t)
   159  
   160  	c, _ = WithDeadline(Background(), time.Now().Add(-shortDuration))
   161  	testDeadline(c, "WithDeadline+inthepast", t)
   162  
   163  	c, _ = WithDeadline(Background(), time.Now())
   164  	testDeadline(c, "WithDeadline+now", t)
   165  }
   166  
   167  func TestTimeout(t *testing.T) {
   168  	t.Parallel()
   169  
   170  	c, _ := WithTimeout(Background(), shortDuration)
   171  	if got, prefix := fmt.Sprint(c), "context.Background.WithDeadline("; !strings.HasPrefix(got, prefix) {
   172  		t.Errorf("c.String() = %q want prefix %q", got, prefix)
   173  	}
   174  	testDeadline(c, "WithTimeout", t)
   175  
   176  	c, _ = WithTimeout(Background(), shortDuration)
   177  	o := otherContext{c}
   178  	testDeadline(o, "WithTimeout+otherContext", t)
   179  
   180  	c, _ = WithTimeout(Background(), shortDuration)
   181  	o = otherContext{c}
   182  	c, _ = WithTimeout(o, veryLongDuration)
   183  	testDeadline(c, "WithTimeout+otherContext+WithTimeout", t)
   184  }
   185  
   186  func TestCanceledTimeout(t *testing.T) {
   187  	c, _ := WithTimeout(Background(), time.Second)
   188  	o := otherContext{c}
   189  	c, cancel := WithTimeout(o, veryLongDuration)
   190  	cancel() // Should propagate synchronously.
   191  	select {
   192  	case <-c.Done():
   193  	default:
   194  		t.Errorf("<-c.Done() blocked, but shouldn't have")
   195  	}
   196  	if e := c.Err(); e != Canceled {
   197  		t.Errorf("c.Err() == %v want %v", e, Canceled)
   198  	}
   199  }
   200  
   201  type key1 int
   202  type key2 int
   203  
   204  var k1 = key1(1)
   205  var k2 = key2(1) // same int as k1, different type
   206  var k3 = key2(3) // same type as k2, different int
   207  
   208  func TestValues(t *testing.T) {
   209  	check := func(c Context, nm, v1, v2, v3 string) {
   210  		if v, ok := c.Value(k1).(string); ok == (len(v1) == 0) || v != v1 {
   211  			t.Errorf(`%s.Value(k1).(string) = %q, %t want %q, %t`, nm, v, ok, v1, len(v1) != 0)
   212  		}
   213  		if v, ok := c.Value(k2).(string); ok == (len(v2) == 0) || v != v2 {
   214  			t.Errorf(`%s.Value(k2).(string) = %q, %t want %q, %t`, nm, v, ok, v2, len(v2) != 0)
   215  		}
   216  		if v, ok := c.Value(k3).(string); ok == (len(v3) == 0) || v != v3 {
   217  			t.Errorf(`%s.Value(k3).(string) = %q, %t want %q, %t`, nm, v, ok, v3, len(v3) != 0)
   218  		}
   219  	}
   220  
   221  	c0 := Background()
   222  	check(c0, "c0", "", "", "")
   223  
   224  	c1 := WithValue(Background(), k1, "c1k1")
   225  	check(c1, "c1", "c1k1", "", "")
   226  
   227  	if got, want := fmt.Sprint(c1), `context.Background.WithValue(type context_test.key1, val c1k1)`; got != want {
   228  		t.Errorf("c.String() = %q want %q", got, want)
   229  	}
   230  
   231  	c2 := WithValue(c1, k2, "c2k2")
   232  	check(c2, "c2", "c1k1", "c2k2", "")
   233  
   234  	c3 := WithValue(c2, k3, "c3k3")
   235  	check(c3, "c2", "c1k1", "c2k2", "c3k3")
   236  
   237  	c4 := WithValue(c3, k1, nil)
   238  	check(c4, "c4", "", "c2k2", "c3k3")
   239  
   240  	o0 := otherContext{Background()}
   241  	check(o0, "o0", "", "", "")
   242  
   243  	o1 := otherContext{WithValue(Background(), k1, "c1k1")}
   244  	check(o1, "o1", "c1k1", "", "")
   245  
   246  	o2 := WithValue(o1, k2, "o2k2")
   247  	check(o2, "o2", "c1k1", "o2k2", "")
   248  
   249  	o3 := otherContext{c4}
   250  	check(o3, "o3", "", "c2k2", "c3k3")
   251  
   252  	o4 := WithValue(o3, k3, nil)
   253  	check(o4, "o4", "", "c2k2", "")
   254  }
   255  
   256  func TestAllocs(t *testing.T) {
   257  	bg := Background()
   258  	for _, test := range []struct {
   259  		desc       string
   260  		f          func()
   261  		limit      float64
   262  		gccgoLimit float64
   263  	}{
   264  		{
   265  			desc:       "Background()",
   266  			f:          func() { Background() },
   267  			limit:      0,
   268  			gccgoLimit: 0,
   269  		},
   270  		{
   271  			desc: fmt.Sprintf("WithValue(bg, %v, nil)", k1),
   272  			f: func() {
   273  				c := WithValue(bg, k1, nil)
   274  				c.Value(k1)
   275  			},
   276  			limit:      3,
   277  			gccgoLimit: 3,
   278  		},
   279  		{
   280  			desc: "WithTimeout(bg, 1*time.Nanosecond)",
   281  			f: func() {
   282  				c, _ := WithTimeout(bg, 1*time.Nanosecond)
   283  				<-c.Done()
   284  			},
   285  			limit:      12,
   286  			gccgoLimit: 15,
   287  		},
   288  		{
   289  			desc: "WithCancel(bg)",
   290  			f: func() {
   291  				c, cancel := WithCancel(bg)
   292  				cancel()
   293  				<-c.Done()
   294  			},
   295  			limit:      5,
   296  			gccgoLimit: 8,
   297  		},
   298  		{
   299  			desc: "WithTimeout(bg, 5*time.Millisecond)",
   300  			f: func() {
   301  				c, cancel := WithTimeout(bg, 5*time.Millisecond)
   302  				cancel()
   303  				<-c.Done()
   304  			},
   305  			limit:      8,
   306  			gccgoLimit: 25,
   307  		},
   308  	} {
   309  		limit := test.limit
   310  		if runtime.Compiler == "gccgo" {
   311  			// gccgo does not yet do escape analysis.
   312  			// TODO(iant): Remove this when gccgo does do escape analysis.
   313  			limit = test.gccgoLimit
   314  		}
   315  		numRuns := 100
   316  		if testing.Short() {
   317  			numRuns = 10
   318  		}
   319  		if n := testing.AllocsPerRun(numRuns, test.f); n > limit {
   320  			t.Errorf("%s allocs = %f want %d", test.desc, n, int(limit))
   321  		}
   322  	}
   323  }
   324  
   325  func TestSimultaneousCancels(t *testing.T) {
   326  	root, cancel := WithCancel(Background())
   327  	m := map[Context]CancelFunc{root: cancel}
   328  	q := []Context{root}
   329  	// Create a tree of contexts.
   330  	for len(q) != 0 && len(m) < 100 {
   331  		parent := q[0]
   332  		q = q[1:]
   333  		for i := 0; i < 4; i++ {
   334  			ctx, cancel := WithCancel(parent)
   335  			m[ctx] = cancel
   336  			q = append(q, ctx)
   337  		}
   338  	}
   339  	// Start all the cancels in a random order.
   340  	var wg sync.WaitGroup
   341  	wg.Add(len(m))
   342  	for _, cancel := range m {
   343  		go func(cancel CancelFunc) {
   344  			cancel()
   345  			wg.Done()
   346  		}(cancel)
   347  	}
   348  
   349  	d := quiescent(t)
   350  	stuck := make(chan struct{})
   351  	timer := time.AfterFunc(d, func() { close(stuck) })
   352  	defer timer.Stop()
   353  
   354  	// Wait on all the contexts in a random order.
   355  	for ctx := range m {
   356  		select {
   357  		case <-ctx.Done():
   358  		case <-stuck:
   359  			buf := make([]byte, 10<<10)
   360  			n := runtime.Stack(buf, true)
   361  			t.Fatalf("timed out after %v waiting for <-ctx.Done(); stacks:\n%s", d, buf[:n])
   362  		}
   363  	}
   364  	// Wait for all the cancel functions to return.
   365  	done := make(chan struct{})
   366  	go func() {
   367  		wg.Wait()
   368  		close(done)
   369  	}()
   370  	select {
   371  	case <-done:
   372  	case <-stuck:
   373  		buf := make([]byte, 10<<10)
   374  		n := runtime.Stack(buf, true)
   375  		t.Fatalf("timed out after %v waiting for cancel functions; stacks:\n%s", d, buf[:n])
   376  	}
   377  }
   378  
   379  func TestInterlockedCancels(t *testing.T) {
   380  	parent, cancelParent := WithCancel(Background())
   381  	child, cancelChild := WithCancel(parent)
   382  	go func() {
   383  		<-parent.Done()
   384  		cancelChild()
   385  	}()
   386  	cancelParent()
   387  	d := quiescent(t)
   388  	timer := time.NewTimer(d)
   389  	defer timer.Stop()
   390  	select {
   391  	case <-child.Done():
   392  	case <-timer.C:
   393  		buf := make([]byte, 10<<10)
   394  		n := runtime.Stack(buf, true)
   395  		t.Fatalf("timed out after %v waiting for child.Done(); stacks:\n%s", d, buf[:n])
   396  	}
   397  }
   398  
   399  func TestLayersCancel(t *testing.T) {
   400  	testLayers(t, time.Now().UnixNano(), false)
   401  }
   402  
   403  func TestLayersTimeout(t *testing.T) {
   404  	testLayers(t, time.Now().UnixNano(), true)
   405  }
   406  
   407  func testLayers(t *testing.T, seed int64, testTimeout bool) {
   408  	t.Parallel()
   409  
   410  	r := rand.New(rand.NewSource(seed))
   411  	prefix := fmt.Sprintf("seed=%d", seed)
   412  	errorf := func(format string, a ...any) {
   413  		t.Errorf(prefix+format, a...)
   414  	}
   415  	const (
   416  		minLayers = 30
   417  	)
   418  	type value int
   419  	var (
   420  		vals      []*value
   421  		cancels   []CancelFunc
   422  		numTimers int
   423  		ctx       = Background()
   424  	)
   425  	for i := 0; i < minLayers || numTimers == 0 || len(cancels) == 0 || len(vals) == 0; i++ {
   426  		switch r.Intn(3) {
   427  		case 0:
   428  			v := new(value)
   429  			ctx = WithValue(ctx, v, v)
   430  			vals = append(vals, v)
   431  		case 1:
   432  			var cancel CancelFunc
   433  			ctx, cancel = WithCancel(ctx)
   434  			cancels = append(cancels, cancel)
   435  		case 2:
   436  			var cancel CancelFunc
   437  			d := veryLongDuration
   438  			if testTimeout {
   439  				d = shortDuration
   440  			}
   441  			ctx, cancel = WithTimeout(ctx, d)
   442  			cancels = append(cancels, cancel)
   443  			numTimers++
   444  		}
   445  	}
   446  	checkValues := func(when string) {
   447  		for _, key := range vals {
   448  			if val := ctx.Value(key).(*value); key != val {
   449  				errorf("%s: ctx.Value(%p) = %p want %p", when, key, val, key)
   450  			}
   451  		}
   452  	}
   453  	if !testTimeout {
   454  		select {
   455  		case <-ctx.Done():
   456  			errorf("ctx should not be canceled yet")
   457  		default:
   458  		}
   459  	}
   460  	if s, prefix := fmt.Sprint(ctx), "context.Background."; !strings.HasPrefix(s, prefix) {
   461  		t.Errorf("ctx.String() = %q want prefix %q", s, prefix)
   462  	}
   463  	t.Log(ctx)
   464  	checkValues("before cancel")
   465  	if testTimeout {
   466  		d := quiescent(t)
   467  		timer := time.NewTimer(d)
   468  		defer timer.Stop()
   469  		select {
   470  		case <-ctx.Done():
   471  		case <-timer.C:
   472  			errorf("ctx should have timed out after %v", d)
   473  		}
   474  		checkValues("after timeout")
   475  	} else {
   476  		cancel := cancels[r.Intn(len(cancels))]
   477  		cancel()
   478  		select {
   479  		case <-ctx.Done():
   480  		default:
   481  			errorf("ctx should be canceled")
   482  		}
   483  		checkValues("after cancel")
   484  	}
   485  }
   486  
   487  func TestWithCancelCanceledParent(t *testing.T) {
   488  	parent, pcancel := WithCancelCause(Background())
   489  	cause := fmt.Errorf("Because!")
   490  	pcancel(cause)
   491  
   492  	c, _ := WithCancel(parent)
   493  	select {
   494  	case <-c.Done():
   495  	default:
   496  		t.Errorf("child not done immediately upon construction")
   497  	}
   498  	if got, want := c.Err(), Canceled; got != want {
   499  		t.Errorf("child not canceled; got = %v, want = %v", got, want)
   500  	}
   501  	if got, want := Cause(c), cause; got != want {
   502  		t.Errorf("child has wrong cause; got = %v, want = %v", got, want)
   503  	}
   504  }
   505  
   506  func TestWithCancelSimultaneouslyCanceledParent(t *testing.T) {
   507  	// Cancel the parent goroutine concurrently with creating a child.
   508  	for i := 0; i < 100; i++ {
   509  		parent, pcancel := WithCancelCause(Background())
   510  		cause := fmt.Errorf("Because!")
   511  		go pcancel(cause)
   512  
   513  		c, _ := WithCancel(parent)
   514  		<-c.Done()
   515  		if got, want := c.Err(), Canceled; got != want {
   516  			t.Errorf("child not canceled; got = %v, want = %v", got, want)
   517  		}
   518  		if got, want := Cause(c), cause; got != want {
   519  			t.Errorf("child has wrong cause; got = %v, want = %v", got, want)
   520  		}
   521  	}
   522  }
   523  
   524  func TestWithValueChecksKey(t *testing.T) {
   525  	panicVal := recoveredValue(func() { _ = WithValue(Background(), []byte("foo"), "bar") })
   526  	if panicVal == nil {
   527  		t.Error("expected panic")
   528  	}
   529  	panicVal = recoveredValue(func() { _ = WithValue(Background(), nil, "bar") })
   530  	if got, want := fmt.Sprint(panicVal), "nil key"; got != want {
   531  		t.Errorf("panic = %q; want %q", got, want)
   532  	}
   533  }
   534  
   535  func TestInvalidDerivedFail(t *testing.T) {
   536  	panicVal := recoveredValue(func() { _, _ = WithCancel(nil) })
   537  	if panicVal == nil {
   538  		t.Error("expected panic")
   539  	}
   540  	panicVal = recoveredValue(func() { _, _ = WithDeadline(nil, time.Now().Add(shortDuration)) })
   541  	if panicVal == nil {
   542  		t.Error("expected panic")
   543  	}
   544  	panicVal = recoveredValue(func() { _ = WithValue(nil, "foo", "bar") })
   545  	if panicVal == nil {
   546  		t.Error("expected panic")
   547  	}
   548  }
   549  
   550  func recoveredValue(fn func()) (v any) {
   551  	defer func() { v = recover() }()
   552  	fn()
   553  	return
   554  }
   555  
   556  func TestDeadlineExceededSupportsTimeout(t *testing.T) {
   557  	i, ok := DeadlineExceeded.(interface {
   558  		Timeout() bool
   559  	})
   560  	if !ok {
   561  		t.Fatal("DeadlineExceeded does not support Timeout interface")
   562  	}
   563  	if !i.Timeout() {
   564  		t.Fatal("wrong value for timeout")
   565  	}
   566  }
// TestCause checks ctx.Err() and Cause(ctx) for contexts built and canceled in
// many combinations. Each case's ctx function builds its context and performs
// its cancel calls in a fixed order, and that order is what the case tests.
// Every case runs as a parallel subtest.
//
// The expectations below show these rules:
//   - When a context is canceled both directly and through its parent, the
//     first cancellation's cause wins.
//   - A plain cancel (or a nil cause) yields Canceled as the cause.
//   - A zero timeout fires at construction with DeadlineExceeded.
//   - WithoutCancel contexts report nil for both Err and Cause.
//
// forever is a timeout chosen so it never fires during the test. The *Cause
// errors are distinct sentinels compared by identity.
func TestCause(t *testing.T) {
	var (
		forever       = 1e6 * time.Second
		parentCause   = fmt.Errorf("parentCause")
		childCause    = fmt.Errorf("childCause")
		tooSlow       = fmt.Errorf("tooSlow")
		finishedEarly = fmt.Errorf("finishedEarly")
	)
	for _, test := range []struct {
		name  string
		ctx   func() Context
		err   error
		cause error
	}{
		// Contexts that are never canceled have neither an error nor a cause.
		{
			name:  "Background",
			ctx:   Background,
			err:   nil,
			cause: nil,
		},
		{
			name:  "TODO",
			ctx:   TODO,
			err:   nil,
			cause: nil,
		},
		{
			name: "WithCancel",
			ctx: func() Context {
				ctx, cancel := WithCancel(Background())
				cancel()
				return ctx
			},
			err:   Canceled,
			cause: Canceled,
		},
		{
			name: "WithCancelCause",
			ctx: func() Context {
				ctx, cancel := WithCancelCause(Background())
				cancel(parentCause)
				return ctx
			},
			err:   Canceled,
			cause: parentCause,
		},
		// A nil cause falls back to Canceled.
		{
			name: "WithCancelCause nil",
			ctx: func() Context {
				ctx, cancel := WithCancelCause(Background())
				cancel(nil)
				return ctx
			},
			err:   Canceled,
			cause: Canceled,
		},
		// Parent/child pairs: whichever cancellation happens first supplies
		// the child's cause.
		{
			name: "WithCancelCause: parent cause before child",
			ctx: func() Context {
				ctx, cancelParent := WithCancelCause(Background())
				ctx, cancelChild := WithCancelCause(ctx)
				cancelParent(parentCause)
				cancelChild(childCause)
				return ctx
			},
			err:   Canceled,
			cause: parentCause,
		},
		{
			name: "WithCancelCause: parent cause after child",
			ctx: func() Context {
				ctx, cancelParent := WithCancelCause(Background())
				ctx, cancelChild := WithCancelCause(ctx)
				cancelChild(childCause)
				cancelParent(parentCause)
				return ctx
			},
			err:   Canceled,
			cause: childCause,
		},
		{
			name: "WithCancelCause: parent cause before nil",
			ctx: func() Context {
				ctx, cancelParent := WithCancelCause(Background())
				ctx, cancelChild := WithCancel(ctx)
				cancelParent(parentCause)
				cancelChild()
				return ctx
			},
			err:   Canceled,
			cause: parentCause,
		},
		{
			name: "WithCancelCause: parent cause after nil",
			ctx: func() Context {
				ctx, cancelParent := WithCancelCause(Background())
				ctx, cancelChild := WithCancel(ctx)
				cancelChild()
				cancelParent(parentCause)
				return ctx
			},
			err:   Canceled,
			cause: Canceled,
		},
		{
			name: "WithCancelCause: child cause after nil",
			ctx: func() Context {
				ctx, cancelParent := WithCancel(Background())
				ctx, cancelChild := WithCancelCause(ctx)
				cancelParent()
				cancelChild(childCause)
				return ctx
			},
			err:   Canceled,
			cause: Canceled,
		},
		{
			name: "WithCancelCause: child cause before nil",
			ctx: func() Context {
				ctx, cancelParent := WithCancel(Background())
				ctx, cancelChild := WithCancelCause(ctx)
				cancelChild(childCause)
				cancelParent()
				return ctx
			},
			err:   Canceled,
			cause: childCause,
		},
		// A zero timeout has already expired when cancel is called, so the
		// deadline (and its cause, if any) wins.
		{
			name: "WithTimeout",
			ctx: func() Context {
				ctx, cancel := WithTimeout(Background(), 0)
				cancel()
				return ctx
			},
			err:   DeadlineExceeded,
			cause: DeadlineExceeded,
		},
		{
			name: "WithTimeout canceled",
			ctx: func() Context {
				ctx, cancel := WithTimeout(Background(), forever)
				cancel()
				return ctx
			},
			err:   Canceled,
			cause: Canceled,
		},
		{
			name: "WithTimeoutCause",
			ctx: func() Context {
				ctx, cancel := WithTimeoutCause(Background(), 0, tooSlow)
				cancel()
				return ctx
			},
			err:   DeadlineExceeded,
			cause: tooSlow,
		},
		// The timeout's cause applies only when the deadline fires; a plain
		// cancel still reports Canceled.
		{
			name: "WithTimeoutCause canceled",
			ctx: func() Context {
				ctx, cancel := WithTimeoutCause(Background(), forever, tooSlow)
				cancel()
				return ctx
			},
			err:   Canceled,
			cause: Canceled,
		},
		{
			name: "WithTimeoutCause stacked",
			ctx: func() Context {
				ctx, cancel := WithCancelCause(Background())
				ctx, _ = WithTimeoutCause(ctx, 0, tooSlow)
				cancel(finishedEarly)
				return ctx
			},
			err:   DeadlineExceeded,
			cause: tooSlow,
		},
		{
			name: "WithTimeoutCause stacked canceled",
			ctx: func() Context {
				ctx, cancel := WithCancelCause(Background())
				ctx, _ = WithTimeoutCause(ctx, forever, tooSlow)
				cancel(finishedEarly)
				return ctx
			},
			err:   Canceled,
			cause: finishedEarly,
		},
		// WithoutCancel hides the parent's cancellation entirely.
		{
			name: "WithoutCancel",
			ctx: func() Context {
				return WithoutCancel(Background())
			},
			err:   nil,
			cause: nil,
		},
		{
			name: "WithoutCancel canceled",
			ctx: func() Context {
				ctx, cancel := WithCancelCause(Background())
				ctx = WithoutCancel(ctx)
				cancel(finishedEarly)
				return ctx
			},
			err:   nil,
			cause: nil,
		},
		{
			name: "WithoutCancel timeout",
			ctx: func() Context {
				ctx, cancel := WithTimeoutCause(Background(), 0, tooSlow)
				ctx = WithoutCancel(ctx)
				cancel()
				return ctx
			},
			err:   nil,
			cause: nil,
		},
	} {
		test := test // per-iteration copy for the parallel subtest (needed before Go 1.22)
		t.Run(test.name, func(t *testing.T) {
			t.Parallel()
			ctx := test.ctx()
			if got, want := ctx.Err(), test.err; want != got {
				t.Errorf("ctx.Err() = %v want %v", got, want)
			}
			if got, want := Cause(ctx), test.cause; want != got {
				t.Errorf("Cause(ctx) = %v want %v", got, want)
			}
		})
	}
}
   801  
   802  func TestCauseRace(t *testing.T) {
   803  	cause := errors.New("TestCauseRace")
   804  	ctx, cancel := WithCancelCause(Background())
   805  	go func() {
   806  		cancel(cause)
   807  	}()
   808  	for {
   809  		// Poll Cause, rather than waiting for Done, to test that
   810  		// access to the underlying cause is synchronized properly.
   811  		if err := Cause(ctx); err != nil {
   812  			if err != cause {
   813  				t.Errorf("Cause returned %v, want %v", err, cause)
   814  			}
   815  			break
   816  		}
   817  		runtime.Gosched()
   818  	}
   819  }
   820  
   821  func TestWithoutCancel(t *testing.T) {
   822  	key, value := "key", "value"
   823  	ctx := WithValue(Background(), key, value)
   824  	ctx = WithoutCancel(ctx)
   825  	if d, ok := ctx.Deadline(); !d.IsZero() || ok != false {
   826  		t.Errorf("ctx.Deadline() = %v, %v want zero, false", d, ok)
   827  	}
   828  	if done := ctx.Done(); done != nil {
   829  		t.Errorf("ctx.Deadline() = %v want nil", done)
   830  	}
   831  	if err := ctx.Err(); err != nil {
   832  		t.Errorf("ctx.Err() = %v want nil", err)
   833  	}
   834  	if v := ctx.Value(key); v != value {
   835  		t.Errorf("ctx.Value(%q) = %q want %q", key, v, value)
   836  	}
   837  }
   838  
   839  type customDoneContext struct {
   840  	Context
   841  	donec chan struct{}
   842  }
   843  
   844  func (c *customDoneContext) Done() <-chan struct{} {
   845  	return c.donec
   846  }
   847  
   848  func TestCustomContextPropagation(t *testing.T) {
   849  	cause := errors.New("TestCustomContextPropagation")
   850  	donec := make(chan struct{})
   851  	ctx1, cancel1 := WithCancelCause(Background())
   852  	ctx2 := &customDoneContext{
   853  		Context: ctx1,
   854  		donec:   donec,
   855  	}
   856  	ctx3, cancel3 := WithCancel(ctx2)
   857  	defer cancel3()
   858  
   859  	cancel1(cause)
   860  	close(donec)
   861  
   862  	<-ctx3.Done()
   863  	if got, want := ctx3.Err(), Canceled; got != want {
   864  		t.Errorf("child not canceled; got = %v, want = %v", got, want)
   865  	}
   866  	if got, want := Cause(ctx3), cause; got != want {
   867  		t.Errorf("child has wrong cause; got = %v, want = %v", got, want)
   868  	}
   869  }
   870  
   871  // customCauseContext is a custom Context used to test context.Cause.
   872  type customCauseContext struct {
   873  	mu   sync.Mutex
   874  	done chan struct{}
   875  	err  error
   876  
   877  	cancelChild CancelFunc
   878  }
   879  
   880  func (ccc *customCauseContext) Deadline() (deadline time.Time, ok bool) {
   881  	return
   882  }
   883  
   884  func (ccc *customCauseContext) Done() <-chan struct{} {
   885  	ccc.mu.Lock()
   886  	defer ccc.mu.Unlock()
   887  	return ccc.done
   888  }
   889  
   890  func (ccc *customCauseContext) Err() error {
   891  	ccc.mu.Lock()
   892  	defer ccc.mu.Unlock()
   893  	return ccc.err
   894  }
   895  
   896  func (ccc *customCauseContext) Value(key any) any {
   897  	return nil
   898  }
   899  
   900  func (ccc *customCauseContext) cancel() {
   901  	ccc.mu.Lock()
   902  	ccc.err = Canceled
   903  	close(ccc.done)
   904  	cancelChild := ccc.cancelChild
   905  	ccc.mu.Unlock()
   906  
   907  	if cancelChild != nil {
   908  		cancelChild()
   909  	}
   910  }
   911  
   912  func (ccc *customCauseContext) setCancelChild(cancelChild CancelFunc) {
   913  	ccc.cancelChild = cancelChild
   914  }
   915  
   916  func TestCustomContextCause(t *testing.T) {
   917  	// Test if we cancel a custom context, Err and Cause return Canceled.
   918  	ccc := &customCauseContext{
   919  		done: make(chan struct{}),
   920  	}
   921  	ccc.cancel()
   922  	if got := ccc.Err(); got != Canceled {
   923  		t.Errorf("ccc.Err() = %v, want %v", got, Canceled)
   924  	}
   925  	if got := Cause(ccc); got != Canceled {
   926  		t.Errorf("Cause(ccc) = %v, want %v", got, Canceled)
   927  	}
   928  
   929  	// Test that if we pass a custom context to WithCancelCause,
   930  	// and then cancel that child context with a cause,
   931  	// that the cause of the child canceled context is correct
   932  	// but that the parent custom context is not canceled.
   933  	ccc = &customCauseContext{
   934  		done: make(chan struct{}),
   935  	}
   936  	ctx, causeFunc := WithCancelCause(ccc)
   937  	cause := errors.New("TestCustomContextCause")
   938  	causeFunc(cause)
   939  	if got := ctx.Err(); got != Canceled {
   940  		t.Errorf("after CancelCauseFunc ctx.Err() = %v, want %v", got, Canceled)
   941  	}
   942  	if got := Cause(ctx); got != cause {
   943  		t.Errorf("after CancelCauseFunc Cause(ctx) = %v, want %v", got, cause)
   944  	}
   945  	if got := ccc.Err(); got != nil {
   946  		t.Errorf("after CancelCauseFunc ccc.Err() = %v, want %v", got, nil)
   947  	}
   948  	if got := Cause(ccc); got != nil {
   949  		t.Errorf("after CancelCauseFunc Cause(ccc) = %v, want %v", got, nil)
   950  	}
   951  
   952  	// Test that if we now cancel the parent custom context,
   953  	// the cause of the child canceled context is still correct,
   954  	// and the parent custom context is canceled without a cause.
   955  	ccc.cancel()
   956  	if got := ctx.Err(); got != Canceled {
   957  		t.Errorf("after CancelCauseFunc ctx.Err() = %v, want %v", got, Canceled)
   958  	}
   959  	if got := Cause(ctx); got != cause {
   960  		t.Errorf("after CancelCauseFunc Cause(ctx) = %v, want %v", got, cause)
   961  	}
   962  	if got := ccc.Err(); got != Canceled {
   963  		t.Errorf("after CancelCauseFunc ccc.Err() = %v, want %v", got, Canceled)
   964  	}
   965  	if got := Cause(ccc); got != Canceled {
   966  		t.Errorf("after CancelCauseFunc Cause(ccc) = %v, want %v", got, Canceled)
   967  	}
   968  
   969  	// Test that if we associate a custom context with a child,
   970  	// then canceling the custom context cancels the child.
   971  	ccc = &customCauseContext{
   972  		done: make(chan struct{}),
   973  	}
   974  	ctx, cancelFunc := WithCancel(ccc)
   975  	ccc.setCancelChild(cancelFunc)
   976  	ccc.cancel()
   977  	if got := ctx.Err(); got != Canceled {
   978  		t.Errorf("after CancelCauseFunc ctx.Err() = %v, want %v", got, Canceled)
   979  	}
   980  	if got := Cause(ctx); got != Canceled {
   981  		t.Errorf("after CancelCauseFunc Cause(ctx) = %v, want %v", got, Canceled)
   982  	}
   983  	if got := ccc.Err(); got != Canceled {
   984  		t.Errorf("after CancelCauseFunc ccc.Err() = %v, want %v", got, Canceled)
   985  	}
   986  	if got := Cause(ccc); got != Canceled {
   987  		t.Errorf("after CancelCauseFunc Cause(ccc) = %v, want %v", got, Canceled)
   988  	}
   989  }
   990  
   991  func TestAfterFuncCalledAfterCancel(t *testing.T) {
   992  	ctx, cancel := WithCancel(Background())
   993  	donec := make(chan struct{})
   994  	stop := AfterFunc(ctx, func() {
   995  		close(donec)
   996  	})
   997  	select {
   998  	case <-donec:
   999  		t.Fatalf("AfterFunc called before context is done")
  1000  	case <-time.After(shortDuration):
  1001  	}
  1002  	cancel()
  1003  	select {
  1004  	case <-donec:
  1005  	case <-time.After(veryLongDuration):
  1006  		t.Fatalf("AfterFunc not called after context is canceled")
  1007  	}
  1008  	if stop() {
  1009  		t.Fatalf("stop() = true, want false")
  1010  	}
  1011  }
  1012  
  1013  func TestAfterFuncCalledAfterTimeout(t *testing.T) {
  1014  	ctx, cancel := WithTimeout(Background(), shortDuration)
  1015  	defer cancel()
  1016  	donec := make(chan struct{})
  1017  	AfterFunc(ctx, func() {
  1018  		close(donec)
  1019  	})
  1020  	select {
  1021  	case <-donec:
  1022  	case <-time.After(veryLongDuration):
  1023  		t.Fatalf("AfterFunc not called after context is canceled")
  1024  	}
  1025  }
  1026  
  1027  func TestAfterFuncCalledImmediately(t *testing.T) {
  1028  	ctx, cancel := WithCancel(Background())
  1029  	cancel()
  1030  	donec := make(chan struct{})
  1031  	AfterFunc(ctx, func() {
  1032  		close(donec)
  1033  	})
  1034  	select {
  1035  	case <-donec:
  1036  	case <-time.After(veryLongDuration):
  1037  		t.Fatalf("AfterFunc not called for already-canceled context")
  1038  	}
  1039  }
  1040  
  1041  func TestAfterFuncNotCalledAfterStop(t *testing.T) {
  1042  	ctx, cancel := WithCancel(Background())
  1043  	donec := make(chan struct{})
  1044  	stop := AfterFunc(ctx, func() {
  1045  		close(donec)
  1046  	})
  1047  	if !stop() {
  1048  		t.Fatalf("stop() = false, want true")
  1049  	}
  1050  	cancel()
  1051  	select {
  1052  	case <-donec:
  1053  		t.Fatalf("AfterFunc called for already-canceled context")
  1054  	case <-time.After(shortDuration):
  1055  	}
  1056  	if stop() {
  1057  		t.Fatalf("stop() = true, want false")
  1058  	}
  1059  }
  1060  
  1061  // This test verifies that cancelling a context does not block waiting for AfterFuncs to finish.
  1062  func TestAfterFuncCalledAsynchronously(t *testing.T) {
  1063  	ctx, cancel := WithCancel(Background())
  1064  	donec := make(chan struct{})
  1065  	stop := AfterFunc(ctx, func() {
  1066  		// The channel send blocks until donec is read from.
  1067  		donec <- struct{}{}
  1068  	})
  1069  	defer stop()
  1070  	cancel()
  1071  	// After cancel returns, read from donec and unblock the AfterFunc.
  1072  	select {
  1073  	case <-donec:
  1074  	case <-time.After(veryLongDuration):
  1075  		t.Fatalf("AfterFunc not called after context is canceled")
  1076  	}
  1077  }
  1078  

View as plain text