Source file src/testing/iotest/reader.go

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package iotest implements Readers and Writers useful mainly for testing.
     6  package iotest
     7  
     8  import (
     9  	"bytes"
    10  	"errors"
    11  	"fmt"
    12  	"io"
    13  )
    14  
    15  // OneByteReader returns a Reader that implements
    16  // each non-empty Read by reading one byte from r.
    17  func OneByteReader(r io.Reader) io.Reader { return &oneByteReader{r} }
    18  
    19  type oneByteReader struct {
    20  	r io.Reader
    21  }
    22  
    23  func (r *oneByteReader) Read(p []byte) (int, error) {
    24  	if len(p) == 0 {
    25  		return 0, nil
    26  	}
    27  	return r.r.Read(p[0:1])
    28  }
    29  
    30  // HalfReader returns a Reader that implements Read
    31  // by reading half as many requested bytes from r.
    32  func HalfReader(r io.Reader) io.Reader { return &halfReader{r} }
    33  
    34  type halfReader struct {
    35  	r io.Reader
    36  }
    37  
    38  func (r *halfReader) Read(p []byte) (int, error) {
    39  	return r.r.Read(p[0 : (len(p)+1)/2])
    40  }
    41  
    42  // DataErrReader changes the way errors are handled by a Reader. Normally, a
    43  // Reader returns an error (typically EOF) from the first Read call after the
    44  // last piece of data is read. DataErrReader wraps a Reader and changes its
    45  // behavior so the final error is returned along with the final data, instead
    46  // of in the first call after the final data.
    47  func DataErrReader(r io.Reader) io.Reader { return &dataErrReader{r, nil, make([]byte, 1024)} }
    48  
    49  type dataErrReader struct {
    50  	r      io.Reader
    51  	unread []byte
    52  	data   []byte
    53  }
    54  
    55  func (r *dataErrReader) Read(p []byte) (n int, err error) {
    56  	// loop because first call needs two reads:
    57  	// one to get data and a second to look for an error.
    58  	for {
    59  		if len(r.unread) == 0 {
    60  			n1, err1 := r.r.Read(r.data)
    61  			r.unread = r.data[0:n1]
    62  			err = err1
    63  		}
    64  		if n > 0 || err != nil {
    65  			break
    66  		}
    67  		n = copy(p, r.unread)
    68  		r.unread = r.unread[n:]
    69  	}
    70  	return
    71  }
    72  
    73  // ErrTimeout is a fake timeout error.
    74  var ErrTimeout = errors.New("timeout")
    75  
    76  // TimeoutReader returns [ErrTimeout] on the second read
    77  // with no data. Subsequent calls to read succeed.
    78  func TimeoutReader(r io.Reader) io.Reader { return &timeoutReader{r, 0} }
    79  
    80  type timeoutReader struct {
    81  	r     io.Reader
    82  	count int
    83  }
    84  
    85  func (r *timeoutReader) Read(p []byte) (int, error) {
    86  	r.count++
    87  	if r.count == 2 {
    88  		return 0, ErrTimeout
    89  	}
    90  	return r.r.Read(p)
    91  }
    92  
    93  // ErrReader returns an [io.Reader] that returns 0, err from all Read calls.
    94  func ErrReader(err error) io.Reader {
    95  	return &errReader{err: err}
    96  }
    97  
    98  type errReader struct {
    99  	err error
   100  }
   101  
   102  func (r *errReader) Read(p []byte) (int, error) {
   103  	return 0, r.err
   104  }
   105  
   106  type smallByteReader struct {
   107  	r   io.Reader
   108  	off int
   109  	n   int
   110  }
   111  
   112  func (r *smallByteReader) Read(p []byte) (int, error) {
   113  	if len(p) == 0 {
   114  		return 0, nil
   115  	}
   116  	r.n = r.n%3 + 1
   117  	n := r.n
   118  	if n > len(p) {
   119  		n = len(p)
   120  	}
   121  	n, err := r.r.Read(p[0:n])
   122  	if err != nil && err != io.EOF {
   123  		err = fmt.Errorf("Read(%d bytes at offset %d): %v", n, r.off, err)
   124  	}
   125  	r.off += n
   126  	return n, err
   127  }
   128  
   129  // TestReader tests that reading from r returns the expected file content.
   130  // It does reads of different sizes, until EOF.
   131  // If r implements [io.ReaderAt] or [io.Seeker], TestReader also checks
   132  // that those operations behave as they should.
   133  //
   134  // If TestReader finds any misbehaviors, it returns an error reporting them.
   135  // The error text may span multiple lines.
   136  func TestReader(r io.Reader, content []byte) error {
   137  	if len(content) > 0 {
   138  		n, err := r.Read(nil)
   139  		if n != 0 || err != nil {
   140  			return fmt.Errorf("Read(0) = %d, %v, want 0, nil", n, err)
   141  		}
   142  	}
   143  
   144  	data, err := io.ReadAll(&smallByteReader{r: r})
   145  	if err != nil {
   146  		return err
   147  	}
   148  	if !bytes.Equal(data, content) {
   149  		return fmt.Errorf("ReadAll(small amounts) = %q\n\twant %q", data, content)
   150  	}
   151  	n, err := r.Read(make([]byte, 10))
   152  	if n != 0 || err != io.EOF {
   153  		return fmt.Errorf("Read(10) at EOF = %v, %v, want 0, EOF", n, err)
   154  	}
   155  
   156  	if r, ok := r.(io.ReadSeeker); ok {
   157  		// Seek(0, 1) should report the current file position (EOF).
   158  		if off, err := r.Seek(0, 1); off != int64(len(content)) || err != nil {
   159  			return fmt.Errorf("Seek(0, 1) from EOF = %d, %v, want %d, nil", off, err, len(content))
   160  		}
   161  
   162  		// Seek backward partway through file, in two steps.
   163  		// If middle == 0, len(content) == 0, can't use the -1 and +1 seeks.
   164  		middle := len(content) - len(content)/3
   165  		if middle > 0 {
   166  			if off, err := r.Seek(-1, 1); off != int64(len(content)-1) || err != nil {
   167  				return fmt.Errorf("Seek(-1, 1) from EOF = %d, %v, want %d, nil", -off, err, len(content)-1)
   168  			}
   169  			if off, err := r.Seek(int64(-len(content)/3), 1); off != int64(middle-1) || err != nil {
   170  				return fmt.Errorf("Seek(%d, 1) from %d = %d, %v, want %d, nil", -len(content)/3, len(content)-1, off, err, middle-1)
   171  			}
   172  			if off, err := r.Seek(+1, 1); off != int64(middle) || err != nil {
   173  				return fmt.Errorf("Seek(+1, 1) from %d = %d, %v, want %d, nil", middle-1, off, err, middle)
   174  			}
   175  		}
   176  
   177  		// Seek(0, 1) should report the current file position (middle).
   178  		if off, err := r.Seek(0, 1); off != int64(middle) || err != nil {
   179  			return fmt.Errorf("Seek(0, 1) from %d = %d, %v, want %d, nil", middle, off, err, middle)
   180  		}
   181  
   182  		// Reading forward should return the last part of the file.
   183  		data, err := io.ReadAll(&smallByteReader{r: r})
   184  		if err != nil {
   185  			return fmt.Errorf("ReadAll from offset %d: %v", middle, err)
   186  		}
   187  		if !bytes.Equal(data, content[middle:]) {
   188  			return fmt.Errorf("ReadAll from offset %d = %q\n\twant %q", middle, data, content[middle:])
   189  		}
   190  
   191  		// Seek relative to end of file, but start elsewhere.
   192  		if off, err := r.Seek(int64(middle/2), 0); off != int64(middle/2) || err != nil {
   193  			return fmt.Errorf("Seek(%d, 0) from EOF = %d, %v, want %d, nil", middle/2, off, err, middle/2)
   194  		}
   195  		if off, err := r.Seek(int64(-len(content)/3), 2); off != int64(middle) || err != nil {
   196  			return fmt.Errorf("Seek(%d, 2) from %d = %d, %v, want %d, nil", -len(content)/3, middle/2, off, err, middle)
   197  		}
   198  
   199  		// Reading forward should return the last part of the file (again).
   200  		data, err = io.ReadAll(&smallByteReader{r: r})
   201  		if err != nil {
   202  			return fmt.Errorf("ReadAll from offset %d: %v", middle, err)
   203  		}
   204  		if !bytes.Equal(data, content[middle:]) {
   205  			return fmt.Errorf("ReadAll from offset %d = %q\n\twant %q", middle, data, content[middle:])
   206  		}
   207  
   208  		// Absolute seek & read forward.
   209  		if off, err := r.Seek(int64(middle/2), 0); off != int64(middle/2) || err != nil {
   210  			return fmt.Errorf("Seek(%d, 0) from EOF = %d, %v, want %d, nil", middle/2, off, err, middle/2)
   211  		}
   212  		data, err = io.ReadAll(r)
   213  		if err != nil {
   214  			return fmt.Errorf("ReadAll from offset %d: %v", middle/2, err)
   215  		}
   216  		if !bytes.Equal(data, content[middle/2:]) {
   217  			return fmt.Errorf("ReadAll from offset %d = %q\n\twant %q", middle/2, data, content[middle/2:])
   218  		}
   219  	}
   220  
   221  	if r, ok := r.(io.ReaderAt); ok {
   222  		data := make([]byte, len(content), len(content)+1)
   223  		for i := range data {
   224  			data[i] = 0xfe
   225  		}
   226  		n, err := r.ReadAt(data, 0)
   227  		if n != len(data) || err != nil && err != io.EOF {
   228  			return fmt.Errorf("ReadAt(%d, 0) = %v, %v, want %d, nil or EOF", len(data), n, err, len(data))
   229  		}
   230  		if !bytes.Equal(data, content) {
   231  			return fmt.Errorf("ReadAt(%d, 0) = %q\n\twant %q", len(data), data, content)
   232  		}
   233  
   234  		n, err = r.ReadAt(data[:1], int64(len(data)))
   235  		if n != 0 || err != io.EOF {
   236  			return fmt.Errorf("ReadAt(1, %d) = %v, %v, want 0, EOF", len(data), n, err)
   237  		}
   238  
   239  		for i := range data {
   240  			data[i] = 0xfe
   241  		}
   242  		n, err = r.ReadAt(data[:cap(data)], 0)
   243  		if n != len(data) || err != io.EOF {
   244  			return fmt.Errorf("ReadAt(%d, 0) = %v, %v, want %d, EOF", cap(data), n, err, len(data))
   245  		}
   246  		if !bytes.Equal(data, content) {
   247  			return fmt.Errorf("ReadAt(%d, 0) = %q\n\twant %q", len(data), data, content)
   248  		}
   249  
   250  		for i := range data {
   251  			data[i] = 0xfe
   252  		}
   253  		for i := range data {
   254  			n, err = r.ReadAt(data[i:i+1], int64(i))
   255  			if n != 1 || err != nil && (i != len(data)-1 || err != io.EOF) {
   256  				want := "nil"
   257  				if i == len(data)-1 {
   258  					want = "nil or EOF"
   259  				}
   260  				return fmt.Errorf("ReadAt(1, %d) = %v, %v, want 1, %s", i, n, err, want)
   261  			}
   262  			if data[i] != content[i] {
   263  				return fmt.Errorf("ReadAt(1, %d) = %q want %q", i, data[i:i+1], content[i:i+1])
   264  			}
   265  		}
   266  	}
   267  	return nil
   268  }
   269  

View as plain text