Source file src/os/writeto_linux_test.go

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package os_test
     6  
     7  import (
     8  	"bytes"
     9  	"internal/poll"
    10  	"io"
    11  	"math/rand"
    12  	"net"
    13  	. "os"
    14  	"strconv"
    15  	"syscall"
    16  	"testing"
    17  	"time"
    18  )
    19  
    20  func TestSendFile(t *testing.T) {
    21  	sizes := []int{
    22  		1,
    23  		42,
    24  		1025,
    25  		syscall.Getpagesize() + 1,
    26  		32769,
    27  	}
    28  	t.Run("sendfile-to-unix", func(t *testing.T) {
    29  		for _, size := range sizes {
    30  			t.Run(strconv.Itoa(size), func(t *testing.T) {
    31  				testSendFile(t, "unix", int64(size))
    32  			})
    33  		}
    34  	})
    35  	t.Run("sendfile-to-tcp", func(t *testing.T) {
    36  		for _, size := range sizes {
    37  			t.Run(strconv.Itoa(size), func(t *testing.T) {
    38  				testSendFile(t, "tcp", int64(size))
    39  			})
    40  		}
    41  	})
    42  }
    43  
    44  func testSendFile(t *testing.T, proto string, size int64) {
    45  	dst, src, recv, data, hook := newSendFileTest(t, proto, size)
    46  
    47  	// Now call WriteTo (through io.Copy), which will hopefully call poll.SendFile
    48  	n, err := io.Copy(dst, src)
    49  	if err != nil {
    50  		t.Fatalf("io.Copy error: %v", err)
    51  	}
    52  
    53  	// We should have called poll.Splice with the right file descriptor arguments.
    54  	if n > 0 && !hook.called {
    55  		t.Fatal("expected to called poll.SendFile")
    56  	}
    57  	if hook.called && hook.srcfd != int(src.Fd()) {
    58  		t.Fatalf("wrong source file descriptor: got %d, want %d", hook.srcfd, src.Fd())
    59  	}
    60  	sc, ok := dst.(syscall.Conn)
    61  	if !ok {
    62  		t.Fatalf("destination is not a syscall.Conn")
    63  	}
    64  	rc, err := sc.SyscallConn()
    65  	if err != nil {
    66  		t.Fatalf("destination SyscallConn error: %v", err)
    67  	}
    68  	if err = rc.Control(func(fd uintptr) {
    69  		if hook.called && hook.dstfd != int(fd) {
    70  			t.Fatalf("wrong destination file descriptor: got %d, want %d", hook.dstfd, int(fd))
    71  		}
    72  	}); err != nil {
    73  		t.Fatalf("destination Conn Control error: %v", err)
    74  	}
    75  
    76  	// Verify the data size and content.
    77  	dataSize := len(data)
    78  	dstData := make([]byte, dataSize)
    79  	m, err := io.ReadFull(recv, dstData)
    80  	if err != nil {
    81  		t.Fatalf("server Conn Read error: %v", err)
    82  	}
    83  	if n != int64(dataSize) {
    84  		t.Fatalf("data length mismatch for io.Copy, got %d, want %d", n, dataSize)
    85  	}
    86  	if m != dataSize {
    87  		t.Fatalf("data length mismatch for net.Conn.Read, got %d, want %d", m, dataSize)
    88  	}
    89  	if !bytes.Equal(dstData, data) {
    90  		t.Errorf("data mismatch, got %s, want %s", dstData, data)
    91  	}
    92  }
    93  
    94  // newSendFileTest initializes a new test for sendfile.
    95  //
    96  // It creates source file and destination sockets, and populates the source file
    97  // with random data of the specified size. It also hooks package os' call
    98  // to poll.Sendfile and returns the hook so it can be inspected.
    99  func newSendFileTest(t *testing.T, proto string, size int64) (net.Conn, *File, net.Conn, []byte, *sendFileHook) {
   100  	t.Helper()
   101  
   102  	hook := hookSendFile(t)
   103  
   104  	client, server := createSocketPair(t, proto)
   105  	tempFile, data := createTempFile(t, size)
   106  
   107  	return client, tempFile, server, data, hook
   108  }
   109  
   110  func hookSendFile(t *testing.T) *sendFileHook {
   111  	h := new(sendFileHook)
   112  	h.install()
   113  	t.Cleanup(h.uninstall)
   114  	return h
   115  }
   116  
   117  type sendFileHook struct {
   118  	called bool
   119  	dstfd  int
   120  	srcfd  int
   121  	remain int64
   122  
   123  	written int64
   124  	handled bool
   125  	err     error
   126  
   127  	original func(dst *poll.FD, src int, remain int64) (int64, error, bool)
   128  }
   129  
   130  func (h *sendFileHook) install() {
   131  	h.original = *PollSendFile
   132  	*PollSendFile = func(dst *poll.FD, src int, remain int64) (int64, error, bool) {
   133  		h.called = true
   134  		h.dstfd = dst.Sysfd
   135  		h.srcfd = src
   136  		h.remain = remain
   137  		h.written, h.err, h.handled = h.original(dst, src, remain)
   138  		return h.written, h.err, h.handled
   139  	}
   140  }
   141  
   142  func (h *sendFileHook) uninstall() {
   143  	*PollSendFile = h.original
   144  }
   145  
   146  func createTempFile(t *testing.T, size int64) (*File, []byte) {
   147  	f, err := CreateTemp(t.TempDir(), "writeto-sendfile-to-socket")
   148  	if err != nil {
   149  		t.Fatalf("failed to create temporary file: %v", err)
   150  	}
   151  	t.Cleanup(func() {
   152  		f.Close()
   153  	})
   154  
   155  	randSeed := time.Now().Unix()
   156  	t.Logf("random data seed: %d\n", randSeed)
   157  	prng := rand.New(rand.NewSource(randSeed))
   158  	data := make([]byte, size)
   159  	prng.Read(data)
   160  	if _, err := f.Write(data); err != nil {
   161  		t.Fatalf("failed to create and feed the file: %v", err)
   162  	}
   163  	if err := f.Sync(); err != nil {
   164  		t.Fatalf("failed to save the file: %v", err)
   165  	}
   166  	if _, err := f.Seek(0, io.SeekStart); err != nil {
   167  		t.Fatalf("failed to rewind the file: %v", err)
   168  	}
   169  
   170  	return f, data
   171  }
   172  

View as plain text