Source file src/os/zero_copy_linux.go

     1  // Copyright 2020 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package os
     6  
     7  import (
     8  	"internal/poll"
     9  	"io"
    10  	"syscall"
    11  )
    12  
    13  var (
    14  	pollCopyFileRange = poll.CopyFileRange
    15  	pollSplice        = poll.Splice
    16  	pollSendFile      = poll.SendFile
    17  )
    18  
    19  func (f *File) writeTo(w io.Writer) (written int64, handled bool, err error) {
    20  	pfd, network := getPollFDAndNetwork(w)
    21  	// TODO(panjf2000): same as File.spliceToFile.
    22  	if pfd == nil || !pfd.IsStream || !isUnixOrTCP(string(network)) {
    23  		return
    24  	}
    25  
    26  	sc, err := f.SyscallConn()
    27  	if err != nil {
    28  		return
    29  	}
    30  
    31  	rerr := sc.Read(func(fd uintptr) (done bool) {
    32  		written, err, handled = pollSendFile(pfd, int(fd), 1<<63-1)
    33  		return true
    34  	})
    35  
    36  	if err == nil {
    37  		err = rerr
    38  	}
    39  
    40  	return written, handled, wrapSyscallError("sendfile", err)
    41  }
    42  
    43  func (f *File) readFrom(r io.Reader) (written int64, handled bool, err error) {
    44  	// Neither copy_file_range(2) nor splice(2) supports destinations opened with
    45  	// O_APPEND, so don't bother to try zero-copy with these system calls.
    46  	//
    47  	// Visit https://man7.org/linux/man-pages/man2/copy_file_range.2.html#ERRORS and
    48  	// https://man7.org/linux/man-pages/man2/splice.2.html#ERRORS for details.
    49  	if f.appendMode {
    50  		return 0, false, nil
    51  	}
    52  
    53  	written, handled, err = f.copyFileRange(r)
    54  	if handled {
    55  		return
    56  	}
    57  	return f.spliceToFile(r)
    58  }
    59  
    60  func (f *File) spliceToFile(r io.Reader) (written int64, handled bool, err error) {
    61  	var (
    62  		remain int64
    63  		lr     *io.LimitedReader
    64  	)
    65  	if lr, r, remain = tryLimitedReader(r); remain <= 0 {
    66  		return 0, true, nil
    67  	}
    68  
    69  	pfd, _ := getPollFDAndNetwork(r)
    70  	// TODO(panjf2000): run some tests to see if we should unlock the non-streams for splice.
    71  	// Streams benefit the most from the splice(2), non-streams are not even supported in old kernels
    72  	// where splice(2) will just return EINVAL; newer kernels support non-streams like UDP, but I really
    73  	// doubt that splice(2) could help non-streams, cuz they usually send small frames respectively
    74  	// and one splice call would result in one frame.
    75  	// splice(2) is suitable for large data but the generation of fragments defeats its edge here.
    76  	// Therefore, don't bother to try splice if the r is not a streaming descriptor.
    77  	if pfd == nil || !pfd.IsStream {
    78  		return
    79  	}
    80  
    81  	var syscallName string
    82  	written, handled, syscallName, err = pollSplice(&f.pfd, pfd, remain)
    83  
    84  	if lr != nil {
    85  		lr.N = remain - written
    86  	}
    87  
    88  	return written, handled, wrapSyscallError(syscallName, err)
    89  }
    90  
    91  func (f *File) copyFileRange(r io.Reader) (written int64, handled bool, err error) {
    92  	var (
    93  		remain int64
    94  		lr     *io.LimitedReader
    95  	)
    96  	if lr, r, remain = tryLimitedReader(r); remain <= 0 {
    97  		return 0, true, nil
    98  	}
    99  
   100  	var src *File
   101  	switch v := r.(type) {
   102  	case *File:
   103  		src = v
   104  	case fileWithoutWriteTo:
   105  		src = v.File
   106  	default:
   107  		return 0, false, nil
   108  	}
   109  
   110  	if src.checkValid("ReadFrom") != nil {
   111  		// Avoid returning the error as we report handled as false,
   112  		// leave further error handling as the responsibility of the caller.
   113  		return 0, false, nil
   114  	}
   115  
   116  	written, handled, err = pollCopyFileRange(&f.pfd, &src.pfd, remain)
   117  	if lr != nil {
   118  		lr.N -= written
   119  	}
   120  	return written, handled, wrapSyscallError("copy_file_range", err)
   121  }
   122  
   123  // getPollFDAndNetwork tries to get the poll.FD and network type from the given interface
   124  // by expecting the underlying type of i to be the implementation of syscall.Conn
   125  // that contains a *net.rawConn.
   126  func getPollFDAndNetwork(i any) (*poll.FD, poll.String) {
   127  	sc, ok := i.(syscall.Conn)
   128  	if !ok {
   129  		return nil, ""
   130  	}
   131  	rc, err := sc.SyscallConn()
   132  	if err != nil {
   133  		return nil, ""
   134  	}
   135  	irc, ok := rc.(interface {
   136  		PollFD() *poll.FD
   137  		Network() poll.String
   138  	})
   139  	if !ok {
   140  		return nil, ""
   141  	}
   142  	return irc.PollFD(), irc.Network()
   143  }
   144  
   145  // tryLimitedReader tries to assert the io.Reader to io.LimitedReader, it returns the io.LimitedReader,
   146  // the underlying io.Reader and the remaining amount of bytes if the assertion succeeds,
   147  // otherwise it just returns the original io.Reader and the theoretical unlimited remaining amount of bytes.
   148  func tryLimitedReader(r io.Reader) (*io.LimitedReader, io.Reader, int64) {
   149  	var remain int64 = 1<<63 - 1 // by default, copy until EOF
   150  
   151  	lr, ok := r.(*io.LimitedReader)
   152  	if !ok {
   153  		return nil, r, remain
   154  	}
   155  
   156  	remain = lr.N
   157  	return lr, lr.R, remain
   158  }
   159  
   160  func isUnixOrTCP(network string) bool {
   161  	switch network {
   162  	case "tcp", "tcp4", "tcp6", "unix":
   163  		return true
   164  	default:
   165  		return false
   166  	}
   167  }
   168  

View as plain text