Source file src/net/splice_test.go

Documentation: net

     1  // Copyright 2018 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // +build linux
     6  
     7  package net
     8  
     9  import (
    10  	"io"
    11  	"io/ioutil"
    12  	"log"
    13  	"os"
    14  	"os/exec"
    15  	"strconv"
    16  	"sync"
    17  	"testing"
    18  	"time"
    19  )
    20  
    21  func TestSplice(t *testing.T) {
    22  	t.Run("tcp-to-tcp", func(t *testing.T) { testSplice(t, "tcp", "tcp") })
    23  	if !testableNetwork("unixgram") {
    24  		t.Skip("skipping unix-to-tcp tests")
    25  	}
    26  	t.Run("unix-to-tcp", func(t *testing.T) { testSplice(t, "unix", "tcp") })
    27  	t.Run("no-unixpacket", testSpliceNoUnixpacket)
    28  	t.Run("no-unixgram", testSpliceNoUnixgram)
    29  }
    30  
    31  func testSplice(t *testing.T, upNet, downNet string) {
    32  	t.Run("simple", spliceTestCase{upNet, downNet, 128, 128, 0}.test)
    33  	t.Run("multipleWrite", spliceTestCase{upNet, downNet, 4096, 1 << 20, 0}.test)
    34  	t.Run("big", spliceTestCase{upNet, downNet, 5 << 20, 1 << 30, 0}.test)
    35  	t.Run("honorsLimitedReader", spliceTestCase{upNet, downNet, 4096, 1 << 20, 1 << 10}.test)
    36  	t.Run("updatesLimitedReaderN", spliceTestCase{upNet, downNet, 1024, 4096, 4096 + 100}.test)
    37  	t.Run("limitedReaderAtLimit", spliceTestCase{upNet, downNet, 32, 128, 128}.test)
    38  	t.Run("readerAtEOF", func(t *testing.T) { testSpliceReaderAtEOF(t, upNet, downNet) })
    39  	t.Run("issue25985", func(t *testing.T) { testSpliceIssue25985(t, upNet, downNet) })
    40  }
    41  
    42  type spliceTestCase struct {
    43  	upNet, downNet string
    44  
    45  	chunkSize, totalSize int
    46  	limitReadSize        int
    47  }
    48  
    49  func (tc spliceTestCase) test(t *testing.T) {
    50  	clientUp, serverUp, err := spliceTestSocketPair(tc.upNet)
    51  	if err != nil {
    52  		t.Fatal(err)
    53  	}
    54  	defer serverUp.Close()
    55  	cleanup, err := startSpliceClient(clientUp, "w", tc.chunkSize, tc.totalSize)
    56  	if err != nil {
    57  		t.Fatal(err)
    58  	}
    59  	defer cleanup()
    60  	clientDown, serverDown, err := spliceTestSocketPair(tc.downNet)
    61  	if err != nil {
    62  		t.Fatal(err)
    63  	}
    64  	defer serverDown.Close()
    65  	cleanup, err = startSpliceClient(clientDown, "r", tc.chunkSize, tc.totalSize)
    66  	if err != nil {
    67  		t.Fatal(err)
    68  	}
    69  	defer cleanup()
    70  	var (
    71  		r    io.Reader = serverUp
    72  		size           = tc.totalSize
    73  	)
    74  	if tc.limitReadSize > 0 {
    75  		if tc.limitReadSize < size {
    76  			size = tc.limitReadSize
    77  		}
    78  
    79  		r = &io.LimitedReader{
    80  			N: int64(tc.limitReadSize),
    81  			R: serverUp,
    82  		}
    83  		defer serverUp.Close()
    84  	}
    85  	n, err := io.Copy(serverDown, r)
    86  	serverDown.Close()
    87  	if err != nil {
    88  		t.Fatal(err)
    89  	}
    90  	if want := int64(size); want != n {
    91  		t.Errorf("want %d bytes spliced, got %d", want, n)
    92  	}
    93  
    94  	if tc.limitReadSize > 0 {
    95  		wantN := 0
    96  		if tc.limitReadSize > size {
    97  			wantN = tc.limitReadSize - size
    98  		}
    99  
   100  		if n := r.(*io.LimitedReader).N; n != int64(wantN) {
   101  			t.Errorf("r.N = %d, want %d", n, wantN)
   102  		}
   103  	}
   104  }
   105  
   106  func testSpliceReaderAtEOF(t *testing.T, upNet, downNet string) {
   107  	clientUp, serverUp, err := spliceTestSocketPair(upNet)
   108  	if err != nil {
   109  		t.Fatal(err)
   110  	}
   111  	defer clientUp.Close()
   112  	clientDown, serverDown, err := spliceTestSocketPair(downNet)
   113  	if err != nil {
   114  		t.Fatal(err)
   115  	}
   116  	defer clientDown.Close()
   117  
   118  	serverUp.Close()
   119  
   120  	// We'd like to call net.splice here and check the handled return
   121  	// value, but we disable splice on old Linux kernels.
   122  	//
   123  	// In that case, poll.Splice and net.splice return a non-nil error
   124  	// and handled == false. We'd ideally like to see handled == true
   125  	// because the source reader is at EOF, but if we're running on an old
   126  	// kernel, and splice is disabled, we won't see EOF from net.splice,
   127  	// because we won't touch the reader at all.
   128  	//
   129  	// Trying to untangle the errors from net.splice and match them
   130  	// against the errors created by the poll package would be brittle,
   131  	// so this is a higher level test.
   132  	//
   133  	// The following ReadFrom should return immediately, regardless of
   134  	// whether splice is disabled or not. The other side should then
   135  	// get a goodbye signal. Test for the goodbye signal.
   136  	msg := "bye"
   137  	go func() {
   138  		serverDown.(io.ReaderFrom).ReadFrom(serverUp)
   139  		io.WriteString(serverDown, msg)
   140  		serverDown.Close()
   141  	}()
   142  
   143  	buf := make([]byte, 3)
   144  	_, err = io.ReadFull(clientDown, buf)
   145  	if err != nil {
   146  		t.Errorf("clientDown: %v", err)
   147  	}
   148  	if string(buf) != msg {
   149  		t.Errorf("clientDown got %q, want %q", buf, msg)
   150  	}
   151  }
   152  
   153  func testSpliceIssue25985(t *testing.T, upNet, downNet string) {
   154  	front, err := newLocalListener(upNet)
   155  	if err != nil {
   156  		t.Fatal(err)
   157  	}
   158  	defer front.Close()
   159  	back, err := newLocalListener(downNet)
   160  	if err != nil {
   161  		t.Fatal(err)
   162  	}
   163  	defer back.Close()
   164  
   165  	var wg sync.WaitGroup
   166  	wg.Add(2)
   167  
   168  	proxy := func() {
   169  		src, err := front.Accept()
   170  		if err != nil {
   171  			return
   172  		}
   173  		dst, err := Dial(downNet, back.Addr().String())
   174  		if err != nil {
   175  			return
   176  		}
   177  		defer dst.Close()
   178  		defer src.Close()
   179  		go func() {
   180  			io.Copy(src, dst)
   181  			wg.Done()
   182  		}()
   183  		go func() {
   184  			io.Copy(dst, src)
   185  			wg.Done()
   186  		}()
   187  	}
   188  
   189  	go proxy()
   190  
   191  	toFront, err := Dial(upNet, front.Addr().String())
   192  	if err != nil {
   193  		t.Fatal(err)
   194  	}
   195  
   196  	io.WriteString(toFront, "foo")
   197  	toFront.Close()
   198  
   199  	fromProxy, err := back.Accept()
   200  	if err != nil {
   201  		t.Fatal(err)
   202  	}
   203  	defer fromProxy.Close()
   204  
   205  	_, err = ioutil.ReadAll(fromProxy)
   206  	if err != nil {
   207  		t.Fatal(err)
   208  	}
   209  
   210  	wg.Wait()
   211  }
   212  
   213  func testSpliceNoUnixpacket(t *testing.T) {
   214  	clientUp, serverUp, err := spliceTestSocketPair("unixpacket")
   215  	if err != nil {
   216  		t.Fatal(err)
   217  	}
   218  	defer clientUp.Close()
   219  	defer serverUp.Close()
   220  	clientDown, serverDown, err := spliceTestSocketPair("tcp")
   221  	if err != nil {
   222  		t.Fatal(err)
   223  	}
   224  	defer clientDown.Close()
   225  	defer serverDown.Close()
   226  	// If splice called poll.Splice here, we'd get err == syscall.EINVAL
   227  	// and handled == false.  If poll.Splice gets an EINVAL on the first
   228  	// try, it assumes the kernel it's running on doesn't support splice
   229  	// for unix sockets and returns handled == false. This works for our
   230  	// purposes by somewhat of an accident, but is not entirely correct.
   231  	//
   232  	// What we want is err == nil and handled == false, i.e. we never
   233  	// called poll.Splice, because we know the unix socket's network.
   234  	_, err, handled := splice(serverDown.(*TCPConn).fd, serverUp)
   235  	if err != nil || handled != false {
   236  		t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled)
   237  	}
   238  }
   239  
   240  func testSpliceNoUnixgram(t *testing.T) {
   241  	addr, err := ResolveUnixAddr("unixgram", testUnixAddr())
   242  	if err != nil {
   243  		t.Fatal(err)
   244  	}
   245  	up, err := ListenUnixgram("unixgram", addr)
   246  	if err != nil {
   247  		t.Fatal(err)
   248  	}
   249  	defer up.Close()
   250  	clientDown, serverDown, err := spliceTestSocketPair("tcp")
   251  	if err != nil {
   252  		t.Fatal(err)
   253  	}
   254  	defer clientDown.Close()
   255  	defer serverDown.Close()
   256  	// Analogous to testSpliceNoUnixpacket.
   257  	_, err, handled := splice(serverDown.(*TCPConn).fd, up)
   258  	if err != nil || handled != false {
   259  		t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled)
   260  	}
   261  }
   262  
   263  func BenchmarkSplice(b *testing.B) {
   264  	testHookUninstaller.Do(uninstallTestHooks)
   265  
   266  	b.Run("tcp-to-tcp", func(b *testing.B) { benchSplice(b, "tcp", "tcp") })
   267  	b.Run("unix-to-tcp", func(b *testing.B) { benchSplice(b, "unix", "tcp") })
   268  }
   269  
   270  func benchSplice(b *testing.B, upNet, downNet string) {
   271  	for i := 0; i <= 10; i++ {
   272  		chunkSize := 1 << uint(i+10)
   273  		tc := spliceTestCase{
   274  			upNet:     upNet,
   275  			downNet:   downNet,
   276  			chunkSize: chunkSize,
   277  		}
   278  
   279  		b.Run(strconv.Itoa(chunkSize), tc.bench)
   280  	}
   281  }
   282  
   283  func (tc spliceTestCase) bench(b *testing.B) {
   284  	// To benchmark the genericReadFrom code path, set this to false.
   285  	useSplice := true
   286  
   287  	clientUp, serverUp, err := spliceTestSocketPair(tc.upNet)
   288  	if err != nil {
   289  		b.Fatal(err)
   290  	}
   291  	defer serverUp.Close()
   292  
   293  	cleanup, err := startSpliceClient(clientUp, "w", tc.chunkSize, tc.chunkSize*b.N)
   294  	if err != nil {
   295  		b.Fatal(err)
   296  	}
   297  	defer cleanup()
   298  
   299  	clientDown, serverDown, err := spliceTestSocketPair(tc.downNet)
   300  	if err != nil {
   301  		b.Fatal(err)
   302  	}
   303  	defer serverDown.Close()
   304  
   305  	cleanup, err = startSpliceClient(clientDown, "r", tc.chunkSize, tc.chunkSize*b.N)
   306  	if err != nil {
   307  		b.Fatal(err)
   308  	}
   309  	defer cleanup()
   310  
   311  	b.SetBytes(int64(tc.chunkSize))
   312  	b.ResetTimer()
   313  
   314  	if useSplice {
   315  		_, err := io.Copy(serverDown, serverUp)
   316  		if err != nil {
   317  			b.Fatal(err)
   318  		}
   319  	} else {
   320  		type onlyReader struct {
   321  			io.Reader
   322  		}
   323  		_, err := io.Copy(serverDown, onlyReader{serverUp})
   324  		if err != nil {
   325  			b.Fatal(err)
   326  		}
   327  	}
   328  }
   329  
   330  func spliceTestSocketPair(net string) (client, server Conn, err error) {
   331  	ln, err := newLocalListener(net)
   332  	if err != nil {
   333  		return nil, nil, err
   334  	}
   335  	defer ln.Close()
   336  	var cerr, serr error
   337  	acceptDone := make(chan struct{})
   338  	go func() {
   339  		server, serr = ln.Accept()
   340  		acceptDone <- struct{}{}
   341  	}()
   342  	client, cerr = Dial(ln.Addr().Network(), ln.Addr().String())
   343  	<-acceptDone
   344  	if cerr != nil {
   345  		if server != nil {
   346  			server.Close()
   347  		}
   348  		return nil, nil, cerr
   349  	}
   350  	if serr != nil {
   351  		if client != nil {
   352  			client.Close()
   353  		}
   354  		return nil, nil, serr
   355  	}
   356  	return client, server, nil
   357  }
   358  
   359  func startSpliceClient(conn Conn, op string, chunkSize, totalSize int) (func(), error) {
   360  	f, err := conn.(interface{ File() (*os.File, error) }).File()
   361  	if err != nil {
   362  		return nil, err
   363  	}
   364  
   365  	cmd := exec.Command(os.Args[0], os.Args[1:]...)
   366  	cmd.Env = []string{
   367  		"GO_NET_TEST_SPLICE=1",
   368  		"GO_NET_TEST_SPLICE_OP=" + op,
   369  		"GO_NET_TEST_SPLICE_CHUNK_SIZE=" + strconv.Itoa(chunkSize),
   370  		"GO_NET_TEST_SPLICE_TOTAL_SIZE=" + strconv.Itoa(totalSize),
   371  	}
   372  	cmd.ExtraFiles = append(cmd.ExtraFiles, f)
   373  	cmd.Stdout = os.Stdout
   374  	cmd.Stderr = os.Stderr
   375  
   376  	if err := cmd.Start(); err != nil {
   377  		return nil, err
   378  	}
   379  
   380  	donec := make(chan struct{})
   381  	go func() {
   382  		cmd.Wait()
   383  		conn.Close()
   384  		f.Close()
   385  		close(donec)
   386  	}()
   387  
   388  	return func() {
   389  		select {
   390  		case <-donec:
   391  		case <-time.After(5 * time.Second):
   392  			log.Printf("killing splice client after 5 second shutdown timeout")
   393  			cmd.Process.Kill()
   394  			select {
   395  			case <-donec:
   396  			case <-time.After(5 * time.Second):
   397  				log.Printf("splice client didn't die after 10 seconds")
   398  			}
   399  		}
   400  	}, nil
   401  }
   402  
   403  func init() {
   404  	if os.Getenv("GO_NET_TEST_SPLICE") == "" {
   405  		return
   406  	}
   407  	defer os.Exit(0)
   408  
   409  	f := os.NewFile(uintptr(3), "splice-test-conn")
   410  	defer f.Close()
   411  
   412  	conn, err := FileConn(f)
   413  	if err != nil {
   414  		log.Fatal(err)
   415  	}
   416  
   417  	var chunkSize int
   418  	if chunkSize, err = strconv.Atoi(os.Getenv("GO_NET_TEST_SPLICE_CHUNK_SIZE")); err != nil {
   419  		log.Fatal(err)
   420  	}
   421  	buf := make([]byte, chunkSize)
   422  
   423  	var totalSize int
   424  	if totalSize, err = strconv.Atoi(os.Getenv("GO_NET_TEST_SPLICE_TOTAL_SIZE")); err != nil {
   425  		log.Fatal(err)
   426  	}
   427  
   428  	var fn func([]byte) (int, error)
   429  	switch op := os.Getenv("GO_NET_TEST_SPLICE_OP"); op {
   430  	case "r":
   431  		fn = conn.Read
   432  	case "w":
   433  		defer conn.Close()
   434  
   435  		fn = conn.Write
   436  	default:
   437  		log.Fatalf("unknown op %q", op)
   438  	}
   439  
   440  	var n int
   441  	for count := 0; count < totalSize; count += n {
   442  		if count+chunkSize > totalSize {
   443  			buf = buf[:totalSize-count]
   444  		}
   445  
   446  		var err error
   447  		if n, err = fn(buf); err != nil {
   448  			return
   449  		}
   450  	}
   451  }
   452  

View as plain text