Source file src/net/splice_test.go

     1  // Copyright 2018 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build linux
     6  
     7  package net
     8  
     9  import (
    10  	"io"
    11  	"log"
    12  	"os"
    13  	"os/exec"
    14  	"strconv"
    15  	"sync"
    16  	"testing"
    17  	"time"
    18  )
    19  
    20  func TestSplice(t *testing.T) {
    21  	t.Run("tcp-to-tcp", func(t *testing.T) { testSplice(t, "tcp", "tcp") })
    22  	if !testableNetwork("unixgram") {
    23  		t.Skip("skipping unix-to-tcp tests")
    24  	}
    25  	t.Run("unix-to-tcp", func(t *testing.T) { testSplice(t, "unix", "tcp") })
    26  	t.Run("tcp-to-unix", func(t *testing.T) { testSplice(t, "tcp", "unix") })
    27  	t.Run("tcp-to-file", func(t *testing.T) { testSpliceToFile(t, "tcp", "file") })
    28  	t.Run("unix-to-file", func(t *testing.T) { testSpliceToFile(t, "unix", "file") })
    29  	t.Run("no-unixpacket", testSpliceNoUnixpacket)
    30  	t.Run("no-unixgram", testSpliceNoUnixgram)
    31  }
    32  
    33  func testSpliceToFile(t *testing.T, upNet, downNet string) {
    34  	t.Run("simple", spliceTestCase{upNet, downNet, 128, 128, 0}.testFile)
    35  	t.Run("multipleWrite", spliceTestCase{upNet, downNet, 4096, 1 << 20, 0}.testFile)
    36  	t.Run("big", spliceTestCase{upNet, downNet, 5 << 20, 1 << 30, 0}.testFile)
    37  	t.Run("honorsLimitedReader", spliceTestCase{upNet, downNet, 4096, 1 << 20, 1 << 10}.testFile)
    38  	t.Run("updatesLimitedReaderN", spliceTestCase{upNet, downNet, 1024, 4096, 4096 + 100}.testFile)
    39  	t.Run("limitedReaderAtLimit", spliceTestCase{upNet, downNet, 32, 128, 128}.testFile)
    40  }
    41  
    42  func testSplice(t *testing.T, upNet, downNet string) {
    43  	t.Run("simple", spliceTestCase{upNet, downNet, 128, 128, 0}.test)
    44  	t.Run("multipleWrite", spliceTestCase{upNet, downNet, 4096, 1 << 20, 0}.test)
    45  	t.Run("big", spliceTestCase{upNet, downNet, 5 << 20, 1 << 30, 0}.test)
    46  	t.Run("honorsLimitedReader", spliceTestCase{upNet, downNet, 4096, 1 << 20, 1 << 10}.test)
    47  	t.Run("updatesLimitedReaderN", spliceTestCase{upNet, downNet, 1024, 4096, 4096 + 100}.test)
    48  	t.Run("limitedReaderAtLimit", spliceTestCase{upNet, downNet, 32, 128, 128}.test)
    49  	t.Run("readerAtEOF", func(t *testing.T) { testSpliceReaderAtEOF(t, upNet, downNet) })
    50  	t.Run("issue25985", func(t *testing.T) { testSpliceIssue25985(t, upNet, downNet) })
    51  }
    52  
    53  type spliceTestCase struct {
    54  	upNet, downNet string
    55  
    56  	chunkSize, totalSize int
    57  	limitReadSize        int
    58  }
    59  
    60  func (tc spliceTestCase) test(t *testing.T) {
    61  	clientUp, serverUp := spliceTestSocketPair(t, tc.upNet)
    62  	defer serverUp.Close()
    63  	cleanup, err := startSpliceClient(clientUp, "w", tc.chunkSize, tc.totalSize)
    64  	if err != nil {
    65  		t.Fatal(err)
    66  	}
    67  	defer cleanup()
    68  	clientDown, serverDown := spliceTestSocketPair(t, tc.downNet)
    69  	defer serverDown.Close()
    70  	cleanup, err = startSpliceClient(clientDown, "r", tc.chunkSize, tc.totalSize)
    71  	if err != nil {
    72  		t.Fatal(err)
    73  	}
    74  	defer cleanup()
    75  	var (
    76  		r    io.Reader = serverUp
    77  		size           = tc.totalSize
    78  	)
    79  	if tc.limitReadSize > 0 {
    80  		if tc.limitReadSize < size {
    81  			size = tc.limitReadSize
    82  		}
    83  
    84  		r = &io.LimitedReader{
    85  			N: int64(tc.limitReadSize),
    86  			R: serverUp,
    87  		}
    88  		defer serverUp.Close()
    89  	}
    90  	n, err := io.Copy(serverDown, r)
    91  	serverDown.Close()
    92  	if err != nil {
    93  		t.Fatal(err)
    94  	}
    95  	if want := int64(size); want != n {
    96  		t.Errorf("want %d bytes spliced, got %d", want, n)
    97  	}
    98  
    99  	if tc.limitReadSize > 0 {
   100  		wantN := 0
   101  		if tc.limitReadSize > size {
   102  			wantN = tc.limitReadSize - size
   103  		}
   104  
   105  		if n := r.(*io.LimitedReader).N; n != int64(wantN) {
   106  			t.Errorf("r.N = %d, want %d", n, wantN)
   107  		}
   108  	}
   109  }
   110  
   111  func (tc spliceTestCase) testFile(t *testing.T) {
   112  	f, err := os.OpenFile(os.DevNull, os.O_WRONLY, 0)
   113  	if err != nil {
   114  		t.Fatal(err)
   115  	}
   116  	defer f.Close()
   117  
   118  	client, server := spliceTestSocketPair(t, tc.upNet)
   119  	defer server.Close()
   120  
   121  	cleanup, err := startSpliceClient(client, "w", tc.chunkSize, tc.totalSize)
   122  	if err != nil {
   123  		client.Close()
   124  		t.Fatal("failed to start splice client:", err)
   125  	}
   126  	defer cleanup()
   127  
   128  	var (
   129  		r          io.Reader = server
   130  		actualSize           = tc.totalSize
   131  	)
   132  	if tc.limitReadSize > 0 {
   133  		if tc.limitReadSize < actualSize {
   134  			actualSize = tc.limitReadSize
   135  		}
   136  
   137  		r = &io.LimitedReader{
   138  			N: int64(tc.limitReadSize),
   139  			R: r,
   140  		}
   141  	}
   142  
   143  	got, err := io.Copy(f, r)
   144  	if err != nil {
   145  		t.Fatalf("failed to ReadFrom with error: %v", err)
   146  	}
   147  	if want := int64(actualSize); got != want {
   148  		t.Errorf("got %d bytes, want %d", got, want)
   149  	}
   150  	if tc.limitReadSize > 0 {
   151  		wantN := 0
   152  		if tc.limitReadSize > actualSize {
   153  			wantN = tc.limitReadSize - actualSize
   154  		}
   155  
   156  		if gotN := r.(*io.LimitedReader).N; gotN != int64(wantN) {
   157  			t.Errorf("r.N = %d, want %d", gotN, wantN)
   158  		}
   159  	}
   160  }
   161  
   162  func testSpliceReaderAtEOF(t *testing.T, upNet, downNet string) {
   163  	// UnixConn doesn't implement io.ReaderFrom, which will fail
   164  	// the following test in asserting a UnixConn to be an io.ReaderFrom,
   165  	// so skip this test.
   166  	if upNet == "unix" || downNet == "unix" {
   167  		t.Skip("skipping test on unix socket")
   168  	}
   169  
   170  	clientUp, serverUp := spliceTestSocketPair(t, upNet)
   171  	defer clientUp.Close()
   172  	clientDown, serverDown := spliceTestSocketPair(t, downNet)
   173  	defer clientDown.Close()
   174  
   175  	serverUp.Close()
   176  
   177  	// We'd like to call net.spliceFrom here and check the handled return
   178  	// value, but we disable splice on old Linux kernels.
   179  	//
   180  	// In that case, poll.Splice and net.spliceFrom return a non-nil error
   181  	// and handled == false. We'd ideally like to see handled == true
   182  	// because the source reader is at EOF, but if we're running on an old
   183  	// kernel, and splice is disabled, we won't see EOF from net.spliceFrom,
   184  	// because we won't touch the reader at all.
   185  	//
   186  	// Trying to untangle the errors from net.spliceFrom and match them
   187  	// against the errors created by the poll package would be brittle,
   188  	// so this is a higher level test.
   189  	//
   190  	// The following ReadFrom should return immediately, regardless of
   191  	// whether splice is disabled or not. The other side should then
   192  	// get a goodbye signal. Test for the goodbye signal.
   193  	msg := "bye"
   194  	go func() {
   195  		serverDown.(io.ReaderFrom).ReadFrom(serverUp)
   196  		io.WriteString(serverDown, msg)
   197  		serverDown.Close()
   198  	}()
   199  
   200  	buf := make([]byte, 3)
   201  	_, err := io.ReadFull(clientDown, buf)
   202  	if err != nil {
   203  		t.Errorf("clientDown: %v", err)
   204  	}
   205  	if string(buf) != msg {
   206  		t.Errorf("clientDown got %q, want %q", buf, msg)
   207  	}
   208  }
   209  
   210  func testSpliceIssue25985(t *testing.T, upNet, downNet string) {
   211  	front := newLocalListener(t, upNet)
   212  	defer front.Close()
   213  	back := newLocalListener(t, downNet)
   214  	defer back.Close()
   215  
   216  	var wg sync.WaitGroup
   217  	wg.Add(2)
   218  
   219  	proxy := func() {
   220  		src, err := front.Accept()
   221  		if err != nil {
   222  			return
   223  		}
   224  		dst, err := Dial(downNet, back.Addr().String())
   225  		if err != nil {
   226  			return
   227  		}
   228  		defer dst.Close()
   229  		defer src.Close()
   230  		go func() {
   231  			io.Copy(src, dst)
   232  			wg.Done()
   233  		}()
   234  		go func() {
   235  			io.Copy(dst, src)
   236  			wg.Done()
   237  		}()
   238  	}
   239  
   240  	go proxy()
   241  
   242  	toFront, err := Dial(upNet, front.Addr().String())
   243  	if err != nil {
   244  		t.Fatal(err)
   245  	}
   246  
   247  	io.WriteString(toFront, "foo")
   248  	toFront.Close()
   249  
   250  	fromProxy, err := back.Accept()
   251  	if err != nil {
   252  		t.Fatal(err)
   253  	}
   254  	defer fromProxy.Close()
   255  
   256  	_, err = io.ReadAll(fromProxy)
   257  	if err != nil {
   258  		t.Fatal(err)
   259  	}
   260  
   261  	wg.Wait()
   262  }
   263  
   264  func testSpliceNoUnixpacket(t *testing.T) {
   265  	clientUp, serverUp := spliceTestSocketPair(t, "unixpacket")
   266  	defer clientUp.Close()
   267  	defer serverUp.Close()
   268  	clientDown, serverDown := spliceTestSocketPair(t, "tcp")
   269  	defer clientDown.Close()
   270  	defer serverDown.Close()
   271  	// If splice called poll.Splice here, we'd get err == syscall.EINVAL
   272  	// and handled == false.  If poll.Splice gets an EINVAL on the first
   273  	// try, it assumes the kernel it's running on doesn't support splice
   274  	// for unix sockets and returns handled == false. This works for our
   275  	// purposes by somewhat of an accident, but is not entirely correct.
   276  	//
   277  	// What we want is err == nil and handled == false, i.e. we never
   278  	// called poll.Splice, because we know the unix socket's network.
   279  	_, err, handled := spliceFrom(serverDown.(*TCPConn).fd, serverUp)
   280  	if err != nil || handled != false {
   281  		t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled)
   282  	}
   283  }
   284  
   285  func testSpliceNoUnixgram(t *testing.T) {
   286  	addr, err := ResolveUnixAddr("unixgram", testUnixAddr(t))
   287  	if err != nil {
   288  		t.Fatal(err)
   289  	}
   290  	defer os.Remove(addr.Name)
   291  	up, err := ListenUnixgram("unixgram", addr)
   292  	if err != nil {
   293  		t.Fatal(err)
   294  	}
   295  	defer up.Close()
   296  	clientDown, serverDown := spliceTestSocketPair(t, "tcp")
   297  	defer clientDown.Close()
   298  	defer serverDown.Close()
   299  	// Analogous to testSpliceNoUnixpacket.
   300  	_, err, handled := spliceFrom(serverDown.(*TCPConn).fd, up)
   301  	if err != nil || handled != false {
   302  		t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled)
   303  	}
   304  }
   305  
   306  func BenchmarkSplice(b *testing.B) {
   307  	testHookUninstaller.Do(uninstallTestHooks)
   308  
   309  	b.Run("tcp-to-tcp", func(b *testing.B) { benchSplice(b, "tcp", "tcp") })
   310  	b.Run("unix-to-tcp", func(b *testing.B) { benchSplice(b, "unix", "tcp") })
   311  	b.Run("tcp-to-unix", func(b *testing.B) { benchSplice(b, "tcp", "unix") })
   312  }
   313  
   314  func benchSplice(b *testing.B, upNet, downNet string) {
   315  	for i := 0; i <= 10; i++ {
   316  		chunkSize := 1 << uint(i+10)
   317  		tc := spliceTestCase{
   318  			upNet:     upNet,
   319  			downNet:   downNet,
   320  			chunkSize: chunkSize,
   321  		}
   322  
   323  		b.Run(strconv.Itoa(chunkSize), tc.bench)
   324  	}
   325  }
   326  
   327  func (tc spliceTestCase) bench(b *testing.B) {
   328  	// To benchmark the genericReadFrom code path, set this to false.
   329  	useSplice := true
   330  
   331  	clientUp, serverUp := spliceTestSocketPair(b, tc.upNet)
   332  	defer serverUp.Close()
   333  
   334  	cleanup, err := startSpliceClient(clientUp, "w", tc.chunkSize, tc.chunkSize*b.N)
   335  	if err != nil {
   336  		b.Fatal(err)
   337  	}
   338  	defer cleanup()
   339  
   340  	clientDown, serverDown := spliceTestSocketPair(b, tc.downNet)
   341  	defer serverDown.Close()
   342  
   343  	cleanup, err = startSpliceClient(clientDown, "r", tc.chunkSize, tc.chunkSize*b.N)
   344  	if err != nil {
   345  		b.Fatal(err)
   346  	}
   347  	defer cleanup()
   348  
   349  	b.SetBytes(int64(tc.chunkSize))
   350  	b.ResetTimer()
   351  
   352  	if useSplice {
   353  		_, err := io.Copy(serverDown, serverUp)
   354  		if err != nil {
   355  			b.Fatal(err)
   356  		}
   357  	} else {
   358  		type onlyReader struct {
   359  			io.Reader
   360  		}
   361  		_, err := io.Copy(serverDown, onlyReader{serverUp})
   362  		if err != nil {
   363  			b.Fatal(err)
   364  		}
   365  	}
   366  }
   367  
   368  func spliceTestSocketPair(t testing.TB, net string) (client, server Conn) {
   369  	t.Helper()
   370  	ln := newLocalListener(t, net)
   371  	defer ln.Close()
   372  	var cerr, serr error
   373  	acceptDone := make(chan struct{})
   374  	go func() {
   375  		server, serr = ln.Accept()
   376  		acceptDone <- struct{}{}
   377  	}()
   378  	client, cerr = Dial(ln.Addr().Network(), ln.Addr().String())
   379  	<-acceptDone
   380  	if cerr != nil {
   381  		if server != nil {
   382  			server.Close()
   383  		}
   384  		t.Fatal(cerr)
   385  	}
   386  	if serr != nil {
   387  		if client != nil {
   388  			client.Close()
   389  		}
   390  		t.Fatal(serr)
   391  	}
   392  	return client, server
   393  }
   394  
   395  func startSpliceClient(conn Conn, op string, chunkSize, totalSize int) (func(), error) {
   396  	f, err := conn.(interface{ File() (*os.File, error) }).File()
   397  	if err != nil {
   398  		return nil, err
   399  	}
   400  
   401  	cmd := exec.Command(os.Args[0], os.Args[1:]...)
   402  	cmd.Env = []string{
   403  		"GO_NET_TEST_SPLICE=1",
   404  		"GO_NET_TEST_SPLICE_OP=" + op,
   405  		"GO_NET_TEST_SPLICE_CHUNK_SIZE=" + strconv.Itoa(chunkSize),
   406  		"GO_NET_TEST_SPLICE_TOTAL_SIZE=" + strconv.Itoa(totalSize),
   407  		"TMPDIR=" + os.Getenv("TMPDIR"),
   408  	}
   409  	cmd.ExtraFiles = append(cmd.ExtraFiles, f)
   410  	cmd.Stdout = os.Stdout
   411  	cmd.Stderr = os.Stderr
   412  
   413  	if err := cmd.Start(); err != nil {
   414  		return nil, err
   415  	}
   416  
   417  	donec := make(chan struct{})
   418  	go func() {
   419  		cmd.Wait()
   420  		conn.Close()
   421  		f.Close()
   422  		close(donec)
   423  	}()
   424  
   425  	return func() {
   426  		select {
   427  		case <-donec:
   428  		case <-time.After(5 * time.Second):
   429  			log.Printf("killing splice client after 5 second shutdown timeout")
   430  			cmd.Process.Kill()
   431  			select {
   432  			case <-donec:
   433  			case <-time.After(5 * time.Second):
   434  				log.Printf("splice client didn't die after 10 seconds")
   435  			}
   436  		}
   437  	}, nil
   438  }
   439  
// init turns the test binary into the splice helper process when it is
// re-executed by startSpliceClient, which sets GO_NET_TEST_SPLICE. In normal
// test runs the variable is unset and init returns immediately.
//
// In helper mode init never hands control back to the testing framework. It
// performs the requested I/O on the connection inherited as file
// descriptor 3 and then exits. It is configured through the environment:
//   - GO_NET_TEST_SPLICE_OP: "r" to read from the connection, "w" to write to it
//   - GO_NET_TEST_SPLICE_CHUNK_SIZE: buffer size for each Read or Write call
//   - GO_NET_TEST_SPLICE_TOTAL_SIZE: number of bytes to transfer
//
// Setup failures go through log.Fatal/log.Fatalf. These exit with status 1
// without running the deferred calls below.
func init() {
	if os.Getenv("GO_NET_TEST_SPLICE") == "" {
		return
	}
	// Deferred first so that it runs last, after the deferred Closes below.
	// The helper exits with status 0 whether the transfer completed or
	// stopped early on an I/O error.
	defer os.Exit(0)

	// The parent passes the connection as cmd.ExtraFiles[0], which the
	// child sees as file descriptor 3.
	f := os.NewFile(uintptr(3), "splice-test-conn")
	defer f.Close()

	conn, err := FileConn(f)
	if err != nil {
		log.Fatal(err)
	}

	var chunkSize int
	if chunkSize, err = strconv.Atoi(os.Getenv("GO_NET_TEST_SPLICE_CHUNK_SIZE")); err != nil {
		log.Fatal(err)
	}
	buf := make([]byte, chunkSize)

	var totalSize int
	if totalSize, err = strconv.Atoi(os.Getenv("GO_NET_TEST_SPLICE_TOTAL_SIZE")); err != nil {
		log.Fatal(err)
	}

	// fn is the per-chunk I/O operation selected by GO_NET_TEST_SPLICE_OP.
	var fn func([]byte) (int, error)
	switch op := os.Getenv("GO_NET_TEST_SPLICE_OP"); op {
	case "r":
		fn = conn.Read
	case "w":
		// NOTE(review): presumably the writer closes its conn on exit so the
		// reading peer observes EOF; confirm that this is the intent.
		defer conn.Close()

		fn = conn.Write
	default:
		log.Fatalf("unknown op %q", op)
	}

	// Transfer up to totalSize bytes, one buffer at a time. The buffer is
	// trimmed near the end so the transfer never exceeds totalSize. n is the
	// number of bytes the previous call actually moved; a Read may return
	// fewer than len(buf). Any error ends the transfer, and the deferred
	// os.Exit(0) then terminates the process.
	var n int
	for count := 0; count < totalSize; count += n {
		if count+chunkSize > totalSize {
			buf = buf[:totalSize-count]
		}

		var err error
		if n, err = fn(buf); err != nil {
			return
		}
	}
}
   489  
   490  func BenchmarkSpliceFile(b *testing.B) {
   491  	b.Run("tcp-to-file", func(b *testing.B) { benchmarkSpliceFile(b, "tcp") })
   492  	b.Run("unix-to-file", func(b *testing.B) { benchmarkSpliceFile(b, "unix") })
   493  }
   494  
   495  func benchmarkSpliceFile(b *testing.B, proto string) {
   496  	for i := 0; i <= 10; i++ {
   497  		size := 1 << (i + 10)
   498  		bench := spliceFileBench{
   499  			proto:     proto,
   500  			chunkSize: size,
   501  		}
   502  		b.Run(strconv.Itoa(size), bench.benchSpliceFile)
   503  	}
   504  }
   505  
   506  type spliceFileBench struct {
   507  	proto     string
   508  	chunkSize int
   509  }
   510  
   511  func (bench spliceFileBench) benchSpliceFile(b *testing.B) {
   512  	f, err := os.OpenFile(os.DevNull, os.O_WRONLY, 0)
   513  	if err != nil {
   514  		b.Fatal(err)
   515  	}
   516  	defer f.Close()
   517  
   518  	totalSize := b.N * bench.chunkSize
   519  
   520  	client, server := spliceTestSocketPair(b, bench.proto)
   521  	defer server.Close()
   522  
   523  	cleanup, err := startSpliceClient(client, "w", bench.chunkSize, totalSize)
   524  	if err != nil {
   525  		client.Close()
   526  		b.Fatalf("failed to start splice client: %v", err)
   527  	}
   528  	defer cleanup()
   529  
   530  	b.ReportAllocs()
   531  	b.SetBytes(int64(bench.chunkSize))
   532  	b.ResetTimer()
   533  
   534  	got, err := io.Copy(f, server)
   535  	if err != nil {
   536  		b.Fatalf("failed to ReadFrom with error: %v", err)
   537  	}
   538  	if want := int64(totalSize); got != want {
   539  		b.Errorf("bytes sent mismatch, got: %d, want: %d", got, want)
   540  	}
   541  }
   542  

View as plain text