Source file src/net/http/responsecontroller_test.go

     1  // Copyright 2022 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package http_test
     6  
     7  import (
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	. "net/http"
    12  	"os"
    13  	"sync"
    14  	"testing"
    15  	"time"
    16  )
    17  
    18  func TestResponseControllerFlush(t *testing.T) { run(t, testResponseControllerFlush) }
    19  func testResponseControllerFlush(t *testing.T, mode testMode) {
    20  	continuec := make(chan struct{})
    21  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
    22  		ctl := NewResponseController(w)
    23  		w.Write([]byte("one"))
    24  		if err := ctl.Flush(); err != nil {
    25  			t.Errorf("ctl.Flush() = %v, want nil", err)
    26  			return
    27  		}
    28  		<-continuec
    29  		w.Write([]byte("two"))
    30  	}))
    31  
    32  	res, err := cst.c.Get(cst.ts.URL)
    33  	if err != nil {
    34  		t.Fatalf("unexpected connection error: %v", err)
    35  	}
    36  	defer res.Body.Close()
    37  
    38  	buf := make([]byte, 16)
    39  	n, err := res.Body.Read(buf)
    40  	close(continuec)
    41  	if err != nil || string(buf[:n]) != "one" {
    42  		t.Fatalf("Body.Read = %q, %v, want %q, nil", string(buf[:n]), err, "one")
    43  	}
    44  
    45  	got, err := io.ReadAll(res.Body)
    46  	if err != nil || string(got) != "two" {
    47  		t.Fatalf("Body.Read = %q, %v, want %q, nil", string(got), err, "two")
    48  	}
    49  }
    50  
    51  func TestResponseControllerHijack(t *testing.T) { run(t, testResponseControllerHijack) }
    52  func testResponseControllerHijack(t *testing.T, mode testMode) {
    53  	const header = "X-Header"
    54  	const value = "set"
    55  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
    56  		ctl := NewResponseController(w)
    57  		c, _, err := ctl.Hijack()
    58  		if mode == http2Mode {
    59  			if err == nil {
    60  				t.Errorf("ctl.Hijack = nil, want error")
    61  			}
    62  			w.Header().Set(header, value)
    63  			return
    64  		}
    65  		if err != nil {
    66  			t.Errorf("ctl.Hijack = _, _, %v, want _, _, nil", err)
    67  			return
    68  		}
    69  		fmt.Fprintf(c, "HTTP/1.0 200 OK\r\n%v: %v\r\nContent-Length: 0\r\n\r\n", header, value)
    70  	}))
    71  	res, err := cst.c.Get(cst.ts.URL)
    72  	if err != nil {
    73  		t.Fatal(err)
    74  	}
    75  	if got, want := res.Header.Get(header), value; got != want {
    76  		t.Errorf("response header %q = %q, want %q", header, got, want)
    77  	}
    78  }
    79  
    80  func TestResponseControllerSetPastWriteDeadline(t *testing.T) {
    81  	run(t, testResponseControllerSetPastWriteDeadline)
    82  }
    83  func testResponseControllerSetPastWriteDeadline(t *testing.T, mode testMode) {
    84  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
    85  		ctl := NewResponseController(w)
    86  		w.Write([]byte("one"))
    87  		if err := ctl.Flush(); err != nil {
    88  			t.Errorf("before setting deadline: ctl.Flush() = %v, want nil", err)
    89  		}
    90  		if err := ctl.SetWriteDeadline(time.Now().Add(-10 * time.Second)); err != nil {
    91  			t.Errorf("ctl.SetWriteDeadline() = %v, want nil", err)
    92  		}
    93  
    94  		w.Write([]byte("two"))
    95  		if err := ctl.Flush(); err == nil {
    96  			t.Errorf("after setting deadline: ctl.Flush() = nil, want non-nil")
    97  		}
    98  		// Connection errors are sticky, so resetting the deadline does not permit
    99  		// making more progress. We might want to change this in the future, but verify
   100  		// the current behavior for now. If we do change this, we'll want to make sure
   101  		// to do so only for writing the response body, not headers.
   102  		if err := ctl.SetWriteDeadline(time.Now().Add(1 * time.Hour)); err != nil {
   103  			t.Errorf("ctl.SetWriteDeadline() = %v, want nil", err)
   104  		}
   105  		w.Write([]byte("three"))
   106  		if err := ctl.Flush(); err == nil {
   107  			t.Errorf("after resetting deadline: ctl.Flush() = nil, want non-nil")
   108  		}
   109  	}))
   110  
   111  	res, err := cst.c.Get(cst.ts.URL)
   112  	if err != nil {
   113  		t.Fatalf("unexpected connection error: %v", err)
   114  	}
   115  	defer res.Body.Close()
   116  	b, _ := io.ReadAll(res.Body)
   117  	if string(b) != "one" {
   118  		t.Errorf("unexpected body: %q", string(b))
   119  	}
   120  }
   121  
   122  func TestResponseControllerSetFutureWriteDeadline(t *testing.T) {
   123  	run(t, testResponseControllerSetFutureWriteDeadline)
   124  }
   125  func testResponseControllerSetFutureWriteDeadline(t *testing.T, mode testMode) {
   126  	errc := make(chan error, 1)
   127  	startwritec := make(chan struct{})
   128  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   129  		ctl := NewResponseController(w)
   130  		w.WriteHeader(200)
   131  		if err := ctl.Flush(); err != nil {
   132  			t.Errorf("ctl.Flush() = %v, want nil", err)
   133  		}
   134  		<-startwritec // don't set the deadline until the client reads response headers
   135  		if err := ctl.SetWriteDeadline(time.Now().Add(1 * time.Millisecond)); err != nil {
   136  			t.Errorf("ctl.SetWriteDeadline() = %v, want nil", err)
   137  		}
   138  		_, err := io.Copy(w, neverEnding('a'))
   139  		errc <- err
   140  	}))
   141  
   142  	res, err := cst.c.Get(cst.ts.URL)
   143  	close(startwritec)
   144  	if err != nil {
   145  		t.Fatalf("unexpected connection error: %v", err)
   146  	}
   147  	defer res.Body.Close()
   148  	_, err = io.Copy(io.Discard, res.Body)
   149  	if err == nil {
   150  		t.Errorf("client reading from truncated request body: got nil error, want non-nil")
   151  	}
   152  	err = <-errc // io.Copy error
   153  	if !errors.Is(err, os.ErrDeadlineExceeded) {
   154  		t.Errorf("server timed out writing request body: got err %v; want os.ErrDeadlineExceeded", err)
   155  	}
   156  }
   157  
   158  func TestResponseControllerSetPastReadDeadline(t *testing.T) {
   159  	run(t, testResponseControllerSetPastReadDeadline)
   160  }
   161  func testResponseControllerSetPastReadDeadline(t *testing.T, mode testMode) {
   162  	readc := make(chan struct{})
   163  	donec := make(chan struct{})
   164  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   165  		defer close(donec)
   166  		ctl := NewResponseController(w)
   167  		b := make([]byte, 3)
   168  		n, err := io.ReadFull(r.Body, b)
   169  		b = b[:n]
   170  		if err != nil || string(b) != "one" {
   171  			t.Errorf("before setting read deadline: Read = %v, %q, want nil, %q", err, string(b), "one")
   172  			return
   173  		}
   174  		if err := ctl.SetReadDeadline(time.Now()); err != nil {
   175  			t.Errorf("ctl.SetReadDeadline() = %v, want nil", err)
   176  			return
   177  		}
   178  		b, err = io.ReadAll(r.Body)
   179  		if err == nil || string(b) != "" {
   180  			t.Errorf("after setting read deadline: Read = %q, nil, want error", string(b))
   181  		}
   182  		close(readc)
   183  		// Connection errors are sticky, so resetting the deadline does not permit
   184  		// making more progress. We might want to change this in the future, but verify
   185  		// the current behavior for now.
   186  		if err := ctl.SetReadDeadline(time.Time{}); err != nil {
   187  			t.Errorf("ctl.SetReadDeadline() = %v, want nil", err)
   188  			return
   189  		}
   190  		b, err = io.ReadAll(r.Body)
   191  		if err == nil {
   192  			t.Errorf("after resetting read deadline: Read = %q, nil, want error", string(b))
   193  		}
   194  	}))
   195  
   196  	pr, pw := io.Pipe()
   197  	var wg sync.WaitGroup
   198  	wg.Add(1)
   199  	go func() {
   200  		defer wg.Done()
   201  		defer pw.Close()
   202  		pw.Write([]byte("one"))
   203  		select {
   204  		case <-readc:
   205  		case <-donec:
   206  			select {
   207  			case <-readc:
   208  			default:
   209  				t.Errorf("server handler unexpectedly exited without closing readc")
   210  				return
   211  			}
   212  		}
   213  		pw.Write([]byte("two"))
   214  	}()
   215  	defer wg.Wait()
   216  	res, err := cst.c.Post(cst.ts.URL, "text/foo", pr)
   217  	if err == nil {
   218  		defer res.Body.Close()
   219  	}
   220  }
   221  
   222  func TestResponseControllerSetFutureReadDeadline(t *testing.T) {
   223  	run(t, testResponseControllerSetFutureReadDeadline)
   224  }
   225  func testResponseControllerSetFutureReadDeadline(t *testing.T, mode testMode) {
   226  	respBody := "response body"
   227  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) {
   228  		ctl := NewResponseController(w)
   229  		if err := ctl.SetReadDeadline(time.Now().Add(1 * time.Millisecond)); err != nil {
   230  			t.Errorf("ctl.SetReadDeadline() = %v, want nil", err)
   231  		}
   232  		_, err := io.Copy(io.Discard, req.Body)
   233  		if !errors.Is(err, os.ErrDeadlineExceeded) {
   234  			t.Errorf("server timed out reading request body: got err %v; want os.ErrDeadlineExceeded", err)
   235  		}
   236  		w.Write([]byte(respBody))
   237  	}))
   238  	pr, pw := io.Pipe()
   239  	res, err := cst.c.Post(cst.ts.URL, "text/apocryphal", pr)
   240  	if err != nil {
   241  		t.Fatal(err)
   242  	}
   243  	defer res.Body.Close()
   244  	got, err := io.ReadAll(res.Body)
   245  	if string(got) != respBody || err != nil {
   246  		t.Errorf("client read response body: %q, %v; want %q, nil", string(got), err, respBody)
   247  	}
   248  	pw.Close()
   249  }
   250  
   251  type wrapWriter struct {
   252  	ResponseWriter
   253  }
   254  
   255  func (w wrapWriter) Unwrap() ResponseWriter {
   256  	return w.ResponseWriter
   257  }
   258  
   259  func TestWrappedResponseController(t *testing.T) { run(t, testWrappedResponseController) }
   260  func testWrappedResponseController(t *testing.T, mode testMode) {
   261  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   262  		w = wrapWriter{w}
   263  		ctl := NewResponseController(w)
   264  		if err := ctl.Flush(); err != nil {
   265  			t.Errorf("ctl.Flush() = %v, want nil", err)
   266  		}
   267  		if err := ctl.SetReadDeadline(time.Time{}); err != nil {
   268  			t.Errorf("ctl.SetReadDeadline() = %v, want nil", err)
   269  		}
   270  		if err := ctl.SetWriteDeadline(time.Time{}); err != nil {
   271  			t.Errorf("ctl.SetWriteDeadline() = %v, want nil", err)
   272  		}
   273  	}))
   274  	res, err := cst.c.Get(cst.ts.URL)
   275  	if err != nil {
   276  		t.Fatalf("unexpected connection error: %v", err)
   277  	}
   278  	io.Copy(io.Discard, res.Body)
   279  	defer res.Body.Close()
   280  }
   281  
   282  func TestResponseControllerEnableFullDuplex(t *testing.T) {
   283  	run(t, testResponseControllerEnableFullDuplex)
   284  }
   285  func testResponseControllerEnableFullDuplex(t *testing.T, mode testMode) {
   286  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) {
   287  		ctl := NewResponseController(w)
   288  		if err := ctl.EnableFullDuplex(); err != nil {
   289  			// TODO: Drop test for HTTP/2 when x/net is updated to support
   290  			// EnableFullDuplex. Since HTTP/2 supports full duplex by default,
   291  			// the rest of the test is fine; it's just the EnableFullDuplex call
   292  			// that fails.
   293  			if mode != http2Mode {
   294  				t.Errorf("ctl.EnableFullDuplex() = %v, want nil", err)
   295  			}
   296  		}
   297  		w.WriteHeader(200)
   298  		ctl.Flush()
   299  		for {
   300  			var buf [1]byte
   301  			n, err := req.Body.Read(buf[:])
   302  			if n != 1 || err != nil {
   303  				break
   304  			}
   305  			w.Write(buf[:])
   306  			ctl.Flush()
   307  		}
   308  	}))
   309  	pr, pw := io.Pipe()
   310  	res, err := cst.c.Post(cst.ts.URL, "text/apocryphal", pr)
   311  	if err != nil {
   312  		t.Fatal(err)
   313  	}
   314  	defer res.Body.Close()
   315  	for i := byte(0); i < 10; i++ {
   316  		if _, err := pw.Write([]byte{i}); err != nil {
   317  			t.Fatalf("Write: %v", err)
   318  		}
   319  		var buf [1]byte
   320  		if n, err := res.Body.Read(buf[:]); n != 1 || err != nil {
   321  			t.Fatalf("Read: %v, %v", n, err)
   322  		}
   323  		if buf[0] != i {
   324  			t.Fatalf("read byte %v, want %v", buf[0], i)
   325  		}
   326  	}
   327  	pw.Close()
   328  }
   329  

View as plain text