Source file src/net/http/httputil/reverseproxy_test.go

Documentation: net/http/httputil

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Reverse proxy tests.
     6  
     7  package httputil
     8  
     9  import (
    10  	"bufio"
    11  	"bytes"
    12  	"context"
    13  	"errors"
    14  	"fmt"
    15  	"io"
    16  	"io/ioutil"
    17  	"log"
    18  	"net/http"
    19  	"net/http/httptest"
    20  	"net/url"
    21  	"os"
    22  	"reflect"
    23  	"sort"
    24  	"strconv"
    25  	"strings"
    26  	"sync"
    27  	"testing"
    28  	"time"
    29  )
    30  
    31  const fakeHopHeader = "X-Fake-Hop-Header-For-Test"
    32  
    33  func init() {
    34  	inOurTests = true
    35  	hopHeaders = append(hopHeaders, fakeHopHeader)
    36  }
    37  
    38  func TestReverseProxy(t *testing.T) {
    39  	const backendResponse = "I am the backend"
    40  	const backendStatus = 404
    41  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    42  		if r.Method == "GET" && r.FormValue("mode") == "hangup" {
    43  			c, _, _ := w.(http.Hijacker).Hijack()
    44  			c.Close()
    45  			return
    46  		}
    47  		if len(r.TransferEncoding) > 0 {
    48  			t.Errorf("backend got unexpected TransferEncoding: %v", r.TransferEncoding)
    49  		}
    50  		if r.Header.Get("X-Forwarded-For") == "" {
    51  			t.Errorf("didn't get X-Forwarded-For header")
    52  		}
    53  		if c := r.Header.Get("Connection"); c != "" {
    54  			t.Errorf("handler got Connection header value %q", c)
    55  		}
    56  		if c := r.Header.Get("Te"); c != "trailers" {
    57  			t.Errorf("handler got Te header value %q; want 'trailers'", c)
    58  		}
    59  		if c := r.Header.Get("Upgrade"); c != "" {
    60  			t.Errorf("handler got Upgrade header value %q", c)
    61  		}
    62  		if c := r.Header.Get("Proxy-Connection"); c != "" {
    63  			t.Errorf("handler got Proxy-Connection header value %q", c)
    64  		}
    65  		if g, e := r.Host, "some-name"; g != e {
    66  			t.Errorf("backend got Host header %q, want %q", g, e)
    67  		}
    68  		w.Header().Set("Trailers", "not a special header field name")
    69  		w.Header().Set("Trailer", "X-Trailer")
    70  		w.Header().Set("X-Foo", "bar")
    71  		w.Header().Set("Upgrade", "foo")
    72  		w.Header().Set(fakeHopHeader, "foo")
    73  		w.Header().Add("X-Multi-Value", "foo")
    74  		w.Header().Add("X-Multi-Value", "bar")
    75  		http.SetCookie(w, &http.Cookie{Name: "flavor", Value: "chocolateChip"})
    76  		w.WriteHeader(backendStatus)
    77  		w.Write([]byte(backendResponse))
    78  		w.Header().Set("X-Trailer", "trailer_value")
    79  		w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value")
    80  	}))
    81  	defer backend.Close()
    82  	backendURL, err := url.Parse(backend.URL)
    83  	if err != nil {
    84  		t.Fatal(err)
    85  	}
    86  	proxyHandler := NewSingleHostReverseProxy(backendURL)
    87  	proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests
    88  	frontend := httptest.NewServer(proxyHandler)
    89  	defer frontend.Close()
    90  	frontendClient := frontend.Client()
    91  
    92  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
    93  	getReq.Host = "some-name"
    94  	getReq.Header.Set("Connection", "close")
    95  	getReq.Header.Set("Te", "trailers")
    96  	getReq.Header.Set("Proxy-Connection", "should be deleted")
    97  	getReq.Header.Set("Upgrade", "foo")
    98  	getReq.Close = true
    99  	res, err := frontendClient.Do(getReq)
   100  	if err != nil {
   101  		t.Fatalf("Get: %v", err)
   102  	}
   103  	if g, e := res.StatusCode, backendStatus; g != e {
   104  		t.Errorf("got res.StatusCode %d; expected %d", g, e)
   105  	}
   106  	if g, e := res.Header.Get("X-Foo"), "bar"; g != e {
   107  		t.Errorf("got X-Foo %q; expected %q", g, e)
   108  	}
   109  	if c := res.Header.Get(fakeHopHeader); c != "" {
   110  		t.Errorf("got %s header value %q", fakeHopHeader, c)
   111  	}
   112  	if g, e := res.Header.Get("Trailers"), "not a special header field name"; g != e {
   113  		t.Errorf("header Trailers = %q; want %q", g, e)
   114  	}
   115  	if g, e := len(res.Header["X-Multi-Value"]), 2; g != e {
   116  		t.Errorf("got %d X-Multi-Value header values; expected %d", g, e)
   117  	}
   118  	if g, e := len(res.Header["Set-Cookie"]), 1; g != e {
   119  		t.Fatalf("got %d SetCookies, want %d", g, e)
   120  	}
   121  	if g, e := res.Trailer, (http.Header{"X-Trailer": nil}); !reflect.DeepEqual(g, e) {
   122  		t.Errorf("before reading body, Trailer = %#v; want %#v", g, e)
   123  	}
   124  	if cookie := res.Cookies()[0]; cookie.Name != "flavor" {
   125  		t.Errorf("unexpected cookie %q", cookie.Name)
   126  	}
   127  	bodyBytes, _ := ioutil.ReadAll(res.Body)
   128  	if g, e := string(bodyBytes), backendResponse; g != e {
   129  		t.Errorf("got body %q; expected %q", g, e)
   130  	}
   131  	if g, e := res.Trailer.Get("X-Trailer"), "trailer_value"; g != e {
   132  		t.Errorf("Trailer(X-Trailer) = %q ; want %q", g, e)
   133  	}
   134  	if g, e := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != e {
   135  		t.Errorf("Trailer(X-Unannounced-Trailer) = %q ; want %q", g, e)
   136  	}
   137  
   138  	// Test that a backend failing to be reached or one which doesn't return
   139  	// a response results in a StatusBadGateway.
   140  	getReq, _ = http.NewRequest("GET", frontend.URL+"/?mode=hangup", nil)
   141  	getReq.Close = true
   142  	res, err = frontendClient.Do(getReq)
   143  	if err != nil {
   144  		t.Fatal(err)
   145  	}
   146  	res.Body.Close()
   147  	if res.StatusCode != http.StatusBadGateway {
   148  		t.Errorf("request to bad proxy = %v; want 502 StatusBadGateway", res.Status)
   149  	}
   150  
   151  }
   152  
   153  // Issue 16875: remove any proxied headers mentioned in the "Connection"
   154  // header value.
   155  func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) {
   156  	const fakeConnectionToken = "X-Fake-Connection-Token"
   157  	const backendResponse = "I am the backend"
   158  
   159  	// someConnHeader is some arbitrary header to be declared as a hop-by-hop header
   160  	// in the Request's Connection header.
   161  	const someConnHeader = "X-Some-Conn-Header"
   162  
   163  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   164  		if c := r.Header.Get("Connection"); c != "" {
   165  			t.Errorf("handler got header %q = %q; want empty", "Connection", c)
   166  		}
   167  		if c := r.Header.Get(fakeConnectionToken); c != "" {
   168  			t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
   169  		}
   170  		if c := r.Header.Get(someConnHeader); c != "" {
   171  			t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
   172  		}
   173  		w.Header().Add("Connection", "Upgrade, "+fakeConnectionToken)
   174  		w.Header().Add("Connection", someConnHeader)
   175  		w.Header().Set(someConnHeader, "should be deleted")
   176  		w.Header().Set(fakeConnectionToken, "should be deleted")
   177  		io.WriteString(w, backendResponse)
   178  	}))
   179  	defer backend.Close()
   180  	backendURL, err := url.Parse(backend.URL)
   181  	if err != nil {
   182  		t.Fatal(err)
   183  	}
   184  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   185  	frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   186  		proxyHandler.ServeHTTP(w, r)
   187  		if c := r.Header.Get(someConnHeader); c != "should be deleted" {
   188  			t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted")
   189  		}
   190  		if c := r.Header.Get(fakeConnectionToken); c != "should be deleted" {
   191  			t.Errorf("handler modified header %q = %q; want %q", fakeConnectionToken, c, "should be deleted")
   192  		}
   193  		c := r.Header["Connection"]
   194  		var cf []string
   195  		for _, f := range c {
   196  			for _, sf := range strings.Split(f, ",") {
   197  				if sf = strings.TrimSpace(sf); sf != "" {
   198  					cf = append(cf, sf)
   199  				}
   200  			}
   201  		}
   202  		sort.Strings(cf)
   203  		expectedValues := []string{"Upgrade", someConnHeader, fakeConnectionToken}
   204  		sort.Strings(expectedValues)
   205  		if !reflect.DeepEqual(cf, expectedValues) {
   206  			t.Errorf("handler modified header %q = %q; want %q", "Connection", cf, expectedValues)
   207  		}
   208  	}))
   209  	defer frontend.Close()
   210  
   211  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   212  	getReq.Header.Add("Connection", "Upgrade, "+fakeConnectionToken)
   213  	getReq.Header.Add("Connection", someConnHeader)
   214  	getReq.Header.Set(someConnHeader, "should be deleted")
   215  	getReq.Header.Set(fakeConnectionToken, "should be deleted")
   216  	res, err := frontend.Client().Do(getReq)
   217  	if err != nil {
   218  		t.Fatalf("Get: %v", err)
   219  	}
   220  	defer res.Body.Close()
   221  	bodyBytes, err := ioutil.ReadAll(res.Body)
   222  	if err != nil {
   223  		t.Fatalf("reading body: %v", err)
   224  	}
   225  	if got, want := string(bodyBytes), backendResponse; got != want {
   226  		t.Errorf("got body %q; want %q", got, want)
   227  	}
   228  	if c := res.Header.Get("Connection"); c != "" {
   229  		t.Errorf("handler got header %q = %q; want empty", "Connection", c)
   230  	}
   231  	if c := res.Header.Get(someConnHeader); c != "" {
   232  		t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
   233  	}
   234  	if c := res.Header.Get(fakeConnectionToken); c != "" {
   235  		t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
   236  	}
   237  }
   238  
   239  func TestXForwardedFor(t *testing.T) {
   240  	const prevForwardedFor = "client ip"
   241  	const backendResponse = "I am the backend"
   242  	const backendStatus = 404
   243  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   244  		if r.Header.Get("X-Forwarded-For") == "" {
   245  			t.Errorf("didn't get X-Forwarded-For header")
   246  		}
   247  		if !strings.Contains(r.Header.Get("X-Forwarded-For"), prevForwardedFor) {
   248  			t.Errorf("X-Forwarded-For didn't contain prior data")
   249  		}
   250  		w.WriteHeader(backendStatus)
   251  		w.Write([]byte(backendResponse))
   252  	}))
   253  	defer backend.Close()
   254  	backendURL, err := url.Parse(backend.URL)
   255  	if err != nil {
   256  		t.Fatal(err)
   257  	}
   258  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   259  	frontend := httptest.NewServer(proxyHandler)
   260  	defer frontend.Close()
   261  
   262  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   263  	getReq.Host = "some-name"
   264  	getReq.Header.Set("Connection", "close")
   265  	getReq.Header.Set("X-Forwarded-For", prevForwardedFor)
   266  	getReq.Close = true
   267  	res, err := frontend.Client().Do(getReq)
   268  	if err != nil {
   269  		t.Fatalf("Get: %v", err)
   270  	}
   271  	if g, e := res.StatusCode, backendStatus; g != e {
   272  		t.Errorf("got res.StatusCode %d; expected %d", g, e)
   273  	}
   274  	bodyBytes, _ := ioutil.ReadAll(res.Body)
   275  	if g, e := string(bodyBytes), backendResponse; g != e {
   276  		t.Errorf("got body %q; expected %q", g, e)
   277  	}
   278  }
   279  
   280  var proxyQueryTests = []struct {
   281  	baseSuffix string // suffix to add to backend URL
   282  	reqSuffix  string // suffix to add to frontend's request URL
   283  	want       string // what backend should see for final request URL (without ?)
   284  }{
   285  	{"", "", ""},
   286  	{"?sta=tic", "?us=er", "sta=tic&us=er"},
   287  	{"", "?us=er", "us=er"},
   288  	{"?sta=tic", "", "sta=tic"},
   289  }
   290  
   291  func TestReverseProxyQuery(t *testing.T) {
   292  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   293  		w.Header().Set("X-Got-Query", r.URL.RawQuery)
   294  		w.Write([]byte("hi"))
   295  	}))
   296  	defer backend.Close()
   297  
   298  	for i, tt := range proxyQueryTests {
   299  		backendURL, err := url.Parse(backend.URL + tt.baseSuffix)
   300  		if err != nil {
   301  			t.Fatal(err)
   302  		}
   303  		frontend := httptest.NewServer(NewSingleHostReverseProxy(backendURL))
   304  		req, _ := http.NewRequest("GET", frontend.URL+tt.reqSuffix, nil)
   305  		req.Close = true
   306  		res, err := frontend.Client().Do(req)
   307  		if err != nil {
   308  			t.Fatalf("%d. Get: %v", i, err)
   309  		}
   310  		if g, e := res.Header.Get("X-Got-Query"), tt.want; g != e {
   311  			t.Errorf("%d. got query %q; expected %q", i, g, e)
   312  		}
   313  		res.Body.Close()
   314  		frontend.Close()
   315  	}
   316  }
   317  
   318  func TestReverseProxyFlushInterval(t *testing.T) {
   319  	const expected = "hi"
   320  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   321  		w.Write([]byte(expected))
   322  	}))
   323  	defer backend.Close()
   324  
   325  	backendURL, err := url.Parse(backend.URL)
   326  	if err != nil {
   327  		t.Fatal(err)
   328  	}
   329  
   330  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   331  	proxyHandler.FlushInterval = time.Microsecond
   332  
   333  	frontend := httptest.NewServer(proxyHandler)
   334  	defer frontend.Close()
   335  
   336  	req, _ := http.NewRequest("GET", frontend.URL, nil)
   337  	req.Close = true
   338  	res, err := frontend.Client().Do(req)
   339  	if err != nil {
   340  		t.Fatalf("Get: %v", err)
   341  	}
   342  	defer res.Body.Close()
   343  	if bodyBytes, _ := ioutil.ReadAll(res.Body); string(bodyBytes) != expected {
   344  		t.Errorf("got body %q; expected %q", bodyBytes, expected)
   345  	}
   346  }
   347  
   348  func TestReverseProxyFlushIntervalHeaders(t *testing.T) {
   349  	const expected = "hi"
   350  	stopCh := make(chan struct{})
   351  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   352  		w.Header().Add("MyHeader", expected)
   353  		w.WriteHeader(200)
   354  		w.(http.Flusher).Flush()
   355  		<-stopCh
   356  	}))
   357  	defer backend.Close()
   358  	defer close(stopCh)
   359  
   360  	backendURL, err := url.Parse(backend.URL)
   361  	if err != nil {
   362  		t.Fatal(err)
   363  	}
   364  
   365  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   366  	proxyHandler.FlushInterval = time.Microsecond
   367  
   368  	frontend := httptest.NewServer(proxyHandler)
   369  	defer frontend.Close()
   370  
   371  	req, _ := http.NewRequest("GET", frontend.URL, nil)
   372  	req.Close = true
   373  
   374  	ctx, cancel := context.WithTimeout(req.Context(), 10*time.Second)
   375  	defer cancel()
   376  	req = req.WithContext(ctx)
   377  
   378  	res, err := frontend.Client().Do(req)
   379  	if err != nil {
   380  		t.Fatalf("Get: %v", err)
   381  	}
   382  	defer res.Body.Close()
   383  
   384  	if res.Header.Get("MyHeader") != expected {
   385  		t.Errorf("got header %q; expected %q", res.Header.Get("MyHeader"), expected)
   386  	}
   387  }
   388  
   389  func TestReverseProxyCancelation(t *testing.T) {
   390  	const backendResponse = "I am the backend"
   391  
   392  	reqInFlight := make(chan struct{})
   393  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   394  		close(reqInFlight) // cause the client to cancel its request
   395  
   396  		select {
   397  		case <-time.After(10 * time.Second):
   398  			// Note: this should only happen in broken implementations, and the
   399  			// closenotify case should be instantaneous.
   400  			t.Error("Handler never saw CloseNotify")
   401  			return
   402  		case <-w.(http.CloseNotifier).CloseNotify():
   403  		}
   404  
   405  		w.WriteHeader(http.StatusOK)
   406  		w.Write([]byte(backendResponse))
   407  	}))
   408  
   409  	defer backend.Close()
   410  
   411  	backend.Config.ErrorLog = log.New(ioutil.Discard, "", 0)
   412  
   413  	backendURL, err := url.Parse(backend.URL)
   414  	if err != nil {
   415  		t.Fatal(err)
   416  	}
   417  
   418  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   419  
   420  	// Discards errors of the form:
   421  	// http: proxy error: read tcp 127.0.0.1:44643: use of closed network connection
   422  	proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0)
   423  
   424  	frontend := httptest.NewServer(proxyHandler)
   425  	defer frontend.Close()
   426  	frontendClient := frontend.Client()
   427  
   428  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   429  	go func() {
   430  		<-reqInFlight
   431  		frontendClient.Transport.(*http.Transport).CancelRequest(getReq)
   432  	}()
   433  	res, err := frontendClient.Do(getReq)
   434  	if res != nil {
   435  		t.Errorf("got response %v; want nil", res.Status)
   436  	}
   437  	if err == nil {
   438  		// This should be an error like:
   439  		// Get http://127.0.0.1:58079: read tcp 127.0.0.1:58079:
   440  		//    use of closed network connection
   441  		t.Error("Server.Client().Do() returned nil error; want non-nil error")
   442  	}
   443  }
   444  
   445  func req(t *testing.T, v string) *http.Request {
   446  	req, err := http.ReadRequest(bufio.NewReader(strings.NewReader(v)))
   447  	if err != nil {
   448  		t.Fatal(err)
   449  	}
   450  	return req
   451  }
   452  
   453  // Issue 12344
   454  func TestNilBody(t *testing.T) {
   455  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   456  		w.Write([]byte("hi"))
   457  	}))
   458  	defer backend.Close()
   459  
   460  	frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
   461  		backURL, _ := url.Parse(backend.URL)
   462  		rp := NewSingleHostReverseProxy(backURL)
   463  		r := req(t, "GET / HTTP/1.0\r\n\r\n")
   464  		r.Body = nil // this accidentally worked in Go 1.4 and below, so keep it working
   465  		rp.ServeHTTP(w, r)
   466  	}))
   467  	defer frontend.Close()
   468  
   469  	res, err := http.Get(frontend.URL)
   470  	if err != nil {
   471  		t.Fatal(err)
   472  	}
   473  	defer res.Body.Close()
   474  	slurp, err := ioutil.ReadAll(res.Body)
   475  	if err != nil {
   476  		t.Fatal(err)
   477  	}
   478  	if string(slurp) != "hi" {
   479  		t.Errorf("Got %q; want %q", slurp, "hi")
   480  	}
   481  }
   482  
   483  // Issue 15524
   484  func TestUserAgentHeader(t *testing.T) {
   485  	const explicitUA = "explicit UA"
   486  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   487  		if r.URL.Path == "/noua" {
   488  			if c := r.Header.Get("User-Agent"); c != "" {
   489  				t.Errorf("handler got non-empty User-Agent header %q", c)
   490  			}
   491  			return
   492  		}
   493  		if c := r.Header.Get("User-Agent"); c != explicitUA {
   494  			t.Errorf("handler got unexpected User-Agent header %q", c)
   495  		}
   496  	}))
   497  	defer backend.Close()
   498  	backendURL, err := url.Parse(backend.URL)
   499  	if err != nil {
   500  		t.Fatal(err)
   501  	}
   502  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   503  	proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests
   504  	frontend := httptest.NewServer(proxyHandler)
   505  	defer frontend.Close()
   506  	frontendClient := frontend.Client()
   507  
   508  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   509  	getReq.Header.Set("User-Agent", explicitUA)
   510  	getReq.Close = true
   511  	res, err := frontendClient.Do(getReq)
   512  	if err != nil {
   513  		t.Fatalf("Get: %v", err)
   514  	}
   515  	res.Body.Close()
   516  
   517  	getReq, _ = http.NewRequest("GET", frontend.URL+"/noua", nil)
   518  	getReq.Header.Set("User-Agent", "")
   519  	getReq.Close = true
   520  	res, err = frontendClient.Do(getReq)
   521  	if err != nil {
   522  		t.Fatalf("Get: %v", err)
   523  	}
   524  	res.Body.Close()
   525  }
   526  
   527  type bufferPool struct {
   528  	get func() []byte
   529  	put func([]byte)
   530  }
   531  
   532  func (bp bufferPool) Get() []byte  { return bp.get() }
   533  func (bp bufferPool) Put(v []byte) { bp.put(v) }
   534  
   535  func TestReverseProxyGetPutBuffer(t *testing.T) {
   536  	const msg = "hi"
   537  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   538  		io.WriteString(w, msg)
   539  	}))
   540  	defer backend.Close()
   541  
   542  	backendURL, err := url.Parse(backend.URL)
   543  	if err != nil {
   544  		t.Fatal(err)
   545  	}
   546  
   547  	var (
   548  		mu  sync.Mutex
   549  		log []string
   550  	)
   551  	addLog := func(event string) {
   552  		mu.Lock()
   553  		defer mu.Unlock()
   554  		log = append(log, event)
   555  	}
   556  	rp := NewSingleHostReverseProxy(backendURL)
   557  	const size = 1234
   558  	rp.BufferPool = bufferPool{
   559  		get: func() []byte {
   560  			addLog("getBuf")
   561  			return make([]byte, size)
   562  		},
   563  		put: func(p []byte) {
   564  			addLog("putBuf-" + strconv.Itoa(len(p)))
   565  		},
   566  	}
   567  	frontend := httptest.NewServer(rp)
   568  	defer frontend.Close()
   569  
   570  	req, _ := http.NewRequest("GET", frontend.URL, nil)
   571  	req.Close = true
   572  	res, err := frontend.Client().Do(req)
   573  	if err != nil {
   574  		t.Fatalf("Get: %v", err)
   575  	}
   576  	slurp, err := ioutil.ReadAll(res.Body)
   577  	res.Body.Close()
   578  	if err != nil {
   579  		t.Fatalf("reading body: %v", err)
   580  	}
   581  	if string(slurp) != msg {
   582  		t.Errorf("msg = %q; want %q", slurp, msg)
   583  	}
   584  	wantLog := []string{"getBuf", "putBuf-" + strconv.Itoa(size)}
   585  	mu.Lock()
   586  	defer mu.Unlock()
   587  	if !reflect.DeepEqual(log, wantLog) {
   588  		t.Errorf("Log events = %q; want %q", log, wantLog)
   589  	}
   590  }
   591  
   592  func TestReverseProxy_Post(t *testing.T) {
   593  	const backendResponse = "I am the backend"
   594  	const backendStatus = 200
   595  	var requestBody = bytes.Repeat([]byte("a"), 1<<20)
   596  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   597  		slurp, err := ioutil.ReadAll(r.Body)
   598  		if err != nil {
   599  			t.Errorf("Backend body read = %v", err)
   600  		}
   601  		if len(slurp) != len(requestBody) {
   602  			t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody))
   603  		}
   604  		if !bytes.Equal(slurp, requestBody) {
   605  			t.Error("Backend read wrong request body.") // 1MB; omitting details
   606  		}
   607  		w.Write([]byte(backendResponse))
   608  	}))
   609  	defer backend.Close()
   610  	backendURL, err := url.Parse(backend.URL)
   611  	if err != nil {
   612  		t.Fatal(err)
   613  	}
   614  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   615  	frontend := httptest.NewServer(proxyHandler)
   616  	defer frontend.Close()
   617  
   618  	postReq, _ := http.NewRequest("POST", frontend.URL, bytes.NewReader(requestBody))
   619  	res, err := frontend.Client().Do(postReq)
   620  	if err != nil {
   621  		t.Fatalf("Do: %v", err)
   622  	}
   623  	if g, e := res.StatusCode, backendStatus; g != e {
   624  		t.Errorf("got res.StatusCode %d; expected %d", g, e)
   625  	}
   626  	bodyBytes, _ := ioutil.ReadAll(res.Body)
   627  	if g, e := string(bodyBytes), backendResponse; g != e {
   628  		t.Errorf("got body %q; expected %q", g, e)
   629  	}
   630  }
   631  
   632  type RoundTripperFunc func(*http.Request) (*http.Response, error)
   633  
   634  func (fn RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
   635  	return fn(req)
   636  }
   637  
   638  // Issue 16036: send a Request with a nil Body when possible
   639  func TestReverseProxy_NilBody(t *testing.T) {
   640  	backendURL, _ := url.Parse("http://fake.tld/")
   641  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   642  	proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests
   643  	proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
   644  		if req.Body != nil {
   645  			t.Error("Body != nil; want a nil Body")
   646  		}
   647  		return nil, errors.New("done testing the interesting part; so force a 502 Gateway error")
   648  	})
   649  	frontend := httptest.NewServer(proxyHandler)
   650  	defer frontend.Close()
   651  
   652  	res, err := frontend.Client().Get(frontend.URL)
   653  	if err != nil {
   654  		t.Fatal(err)
   655  	}
   656  	defer res.Body.Close()
   657  	if res.StatusCode != 502 {
   658  		t.Errorf("status code = %v; want 502 (Gateway Error)", res.Status)
   659  	}
   660  }
   661  
   662  // Issue 33142: always allocate the request headers
   663  func TestReverseProxy_AllocatedHeader(t *testing.T) {
   664  	proxyHandler := new(ReverseProxy)
   665  	proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests
   666  	proxyHandler.Director = func(*http.Request) {}         // noop
   667  	proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
   668  		if req.Header == nil {
   669  			t.Error("Header == nil; want a non-nil Header")
   670  		}
   671  		return nil, errors.New("done testing the interesting part; so force a 502 Gateway error")
   672  	})
   673  
   674  	proxyHandler.ServeHTTP(httptest.NewRecorder(), &http.Request{
   675  		Method:     "GET",
   676  		URL:        &url.URL{Scheme: "http", Host: "fake.tld", Path: "/"},
   677  		Proto:      "HTTP/1.0",
   678  		ProtoMajor: 1,
   679  	})
   680  }
   681  
   682  // Issue 14237. Test ModifyResponse and that an error from it
   683  // causes the proxy to return StatusBadGateway, or StatusOK otherwise.
   684  func TestReverseProxyModifyResponse(t *testing.T) {
   685  	backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   686  		w.Header().Add("X-Hit-Mod", fmt.Sprintf("%v", r.URL.Path == "/mod"))
   687  	}))
   688  	defer backendServer.Close()
   689  
   690  	rpURL, _ := url.Parse(backendServer.URL)
   691  	rproxy := NewSingleHostReverseProxy(rpURL)
   692  	rproxy.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests
   693  	rproxy.ModifyResponse = func(resp *http.Response) error {
   694  		if resp.Header.Get("X-Hit-Mod") != "true" {
   695  			return fmt.Errorf("tried to by-pass proxy")
   696  		}
   697  		return nil
   698  	}
   699  
   700  	frontendProxy := httptest.NewServer(rproxy)
   701  	defer frontendProxy.Close()
   702  
   703  	tests := []struct {
   704  		url      string
   705  		wantCode int
   706  	}{
   707  		{frontendProxy.URL + "/mod", http.StatusOK},
   708  		{frontendProxy.URL + "/schedule", http.StatusBadGateway},
   709  	}
   710  
   711  	for i, tt := range tests {
   712  		resp, err := http.Get(tt.url)
   713  		if err != nil {
   714  			t.Fatalf("failed to reach proxy: %v", err)
   715  		}
   716  		if g, e := resp.StatusCode, tt.wantCode; g != e {
   717  			t.Errorf("#%d: got res.StatusCode %d; expected %d", i, g, e)
   718  		}
   719  		resp.Body.Close()
   720  	}
   721  }
   722  
   723  type failingRoundTripper struct{}
   724  
   725  func (failingRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
   726  	return nil, errors.New("some error")
   727  }
   728  
   729  type staticResponseRoundTripper struct{ res *http.Response }
   730  
   731  func (rt staticResponseRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
   732  	return rt.res, nil
   733  }
   734  
   735  func TestReverseProxyErrorHandler(t *testing.T) {
   736  	tests := []struct {
   737  		name           string
   738  		wantCode       int
   739  		errorHandler   func(http.ResponseWriter, *http.Request, error)
   740  		transport      http.RoundTripper // defaults to failingRoundTripper
   741  		modifyResponse func(*http.Response) error
   742  	}{
   743  		{
   744  			name:     "default",
   745  			wantCode: http.StatusBadGateway,
   746  		},
   747  		{
   748  			name:         "errorhandler",
   749  			wantCode:     http.StatusTeapot,
   750  			errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
   751  		},
   752  		{
   753  			name: "modifyresponse_noerr",
   754  			transport: staticResponseRoundTripper{
   755  				&http.Response{StatusCode: 345, Body: http.NoBody},
   756  			},
   757  			modifyResponse: func(res *http.Response) error {
   758  				res.StatusCode++
   759  				return nil
   760  			},
   761  			errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
   762  			wantCode:     346,
   763  		},
   764  		{
   765  			name: "modifyresponse_err",
   766  			transport: staticResponseRoundTripper{
   767  				&http.Response{StatusCode: 345, Body: http.NoBody},
   768  			},
   769  			modifyResponse: func(res *http.Response) error {
   770  				res.StatusCode++
   771  				return errors.New("some error to trigger errorHandler")
   772  			},
   773  			errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
   774  			wantCode:     http.StatusTeapot,
   775  		},
   776  	}
   777  
   778  	for _, tt := range tests {
   779  		t.Run(tt.name, func(t *testing.T) {
   780  			target := &url.URL{
   781  				Scheme: "http",
   782  				Host:   "dummy.tld",
   783  				Path:   "/",
   784  			}
   785  			rproxy := NewSingleHostReverseProxy(target)
   786  			rproxy.Transport = tt.transport
   787  			rproxy.ModifyResponse = tt.modifyResponse
   788  			if rproxy.Transport == nil {
   789  				rproxy.Transport = failingRoundTripper{}
   790  			}
   791  			rproxy.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests
   792  			if tt.errorHandler != nil {
   793  				rproxy.ErrorHandler = tt.errorHandler
   794  			}
   795  			frontendProxy := httptest.NewServer(rproxy)
   796  			defer frontendProxy.Close()
   797  
   798  			resp, err := http.Get(frontendProxy.URL + "/test")
   799  			if err != nil {
   800  				t.Fatalf("failed to reach proxy: %v", err)
   801  			}
   802  			if g, e := resp.StatusCode, tt.wantCode; g != e {
   803  				t.Errorf("got res.StatusCode %d; expected %d", g, e)
   804  			}
   805  			resp.Body.Close()
   806  		})
   807  	}
   808  }
   809  
   810  // Issue 16659: log errors from short read
   811  func TestReverseProxy_CopyBuffer(t *testing.T) {
   812  	backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   813  		out := "this call was relayed by the reverse proxy"
   814  		// Coerce a wrong content length to induce io.UnexpectedEOF
   815  		w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
   816  		fmt.Fprintln(w, out)
   817  	}))
   818  	defer backendServer.Close()
   819  
   820  	rpURL, err := url.Parse(backendServer.URL)
   821  	if err != nil {
   822  		t.Fatal(err)
   823  	}
   824  
   825  	var proxyLog bytes.Buffer
   826  	rproxy := NewSingleHostReverseProxy(rpURL)
   827  	rproxy.ErrorLog = log.New(&proxyLog, "", log.Lshortfile)
   828  	donec := make(chan bool, 1)
   829  	frontendProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   830  		defer func() { donec <- true }()
   831  		rproxy.ServeHTTP(w, r)
   832  	}))
   833  	defer frontendProxy.Close()
   834  
   835  	if _, err = frontendProxy.Client().Get(frontendProxy.URL); err == nil {
   836  		t.Fatalf("want non-nil error")
   837  	}
   838  	// The race detector complains about the proxyLog usage in logf in copyBuffer
   839  	// and our usage below with proxyLog.Bytes() so we're explicitly using a
   840  	// channel to ensure that the ReverseProxy's ServeHTTP is done before we
   841  	// continue after Get.
   842  	<-donec
   843  
   844  	expected := []string{
   845  		"EOF",
   846  		"read",
   847  	}
   848  	for _, phrase := range expected {
   849  		if !bytes.Contains(proxyLog.Bytes(), []byte(phrase)) {
   850  			t.Errorf("expected log to contain phrase %q", phrase)
   851  		}
   852  	}
   853  }
   854  
   855  type staticTransport struct {
   856  	res *http.Response
   857  }
   858  
   859  func (t *staticTransport) RoundTrip(r *http.Request) (*http.Response, error) {
   860  	return t.res, nil
   861  }
   862  
   863  func BenchmarkServeHTTP(b *testing.B) {
   864  	res := &http.Response{
   865  		StatusCode: 200,
   866  		Body:       ioutil.NopCloser(strings.NewReader("")),
   867  	}
   868  	proxy := &ReverseProxy{
   869  		Director:  func(*http.Request) {},
   870  		Transport: &staticTransport{res},
   871  	}
   872  
   873  	w := httptest.NewRecorder()
   874  	r := httptest.NewRequest("GET", "/", nil)
   875  
   876  	b.ReportAllocs()
   877  	for i := 0; i < b.N; i++ {
   878  		proxy.ServeHTTP(w, r)
   879  	}
   880  }
   881  
   882  func TestServeHTTPDeepCopy(t *testing.T) {
   883  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   884  		w.Write([]byte("Hello Gopher!"))
   885  	}))
   886  	defer backend.Close()
   887  	backendURL, err := url.Parse(backend.URL)
   888  	if err != nil {
   889  		t.Fatal(err)
   890  	}
   891  
   892  	type result struct {
   893  		before, after string
   894  	}
   895  
   896  	resultChan := make(chan result, 1)
   897  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   898  	frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   899  		before := r.URL.String()
   900  		proxyHandler.ServeHTTP(w, r)
   901  		after := r.URL.String()
   902  		resultChan <- result{before: before, after: after}
   903  	}))
   904  	defer frontend.Close()
   905  
   906  	want := result{before: "/", after: "/"}
   907  
   908  	res, err := frontend.Client().Get(frontend.URL)
   909  	if err != nil {
   910  		t.Fatalf("Do: %v", err)
   911  	}
   912  	res.Body.Close()
   913  
   914  	got := <-resultChan
   915  	if got != want {
   916  		t.Errorf("got = %+v; want = %+v", got, want)
   917  	}
   918  }
   919  
   920  // Issue 18327: verify we always do a deep copy of the Request.Header map
   921  // before any mutations.
   922  func TestClonesRequestHeaders(t *testing.T) {
   923  	log.SetOutput(ioutil.Discard)
   924  	defer log.SetOutput(os.Stderr)
   925  	req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
   926  	req.RemoteAddr = "1.2.3.4:56789"
   927  	rp := &ReverseProxy{
   928  		Director: func(req *http.Request) {
   929  			req.Header.Set("From-Director", "1")
   930  		},
   931  		Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
   932  			if v := req.Header.Get("From-Director"); v != "1" {
   933  				t.Errorf("From-Directory value = %q; want 1", v)
   934  			}
   935  			return nil, io.EOF
   936  		}),
   937  	}
   938  	rp.ServeHTTP(httptest.NewRecorder(), req)
   939  
   940  	if req.Header.Get("From-Director") == "1" {
   941  		t.Error("Director header mutation modified caller's request")
   942  	}
   943  	if req.Header.Get("X-Forwarded-For") != "" {
   944  		t.Error("X-Forward-For header mutation modified caller's request")
   945  	}
   946  
   947  }
   948  
   949  type roundTripperFunc func(req *http.Request) (*http.Response, error)
   950  
   951  func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
   952  	return fn(req)
   953  }
   954  
   955  func TestModifyResponseClosesBody(t *testing.T) {
   956  	req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
   957  	req.RemoteAddr = "1.2.3.4:56789"
   958  	closeCheck := new(checkCloser)
   959  	logBuf := new(bytes.Buffer)
   960  	outErr := errors.New("ModifyResponse error")
   961  	rp := &ReverseProxy{
   962  		Director: func(req *http.Request) {},
   963  		Transport: &staticTransport{&http.Response{
   964  			StatusCode: 200,
   965  			Body:       closeCheck,
   966  		}},
   967  		ErrorLog: log.New(logBuf, "", 0),
   968  		ModifyResponse: func(*http.Response) error {
   969  			return outErr
   970  		},
   971  	}
   972  	rec := httptest.NewRecorder()
   973  	rp.ServeHTTP(rec, req)
   974  	res := rec.Result()
   975  	if g, e := res.StatusCode, http.StatusBadGateway; g != e {
   976  		t.Errorf("got res.StatusCode %d; expected %d", g, e)
   977  	}
   978  	if !closeCheck.closed {
   979  		t.Errorf("body should have been closed")
   980  	}
   981  	if g, e := logBuf.String(), outErr.Error(); !strings.Contains(g, e) {
   982  		t.Errorf("ErrorLog %q does not contain %q", g, e)
   983  	}
   984  }
   985  
   986  type checkCloser struct {
   987  	closed bool
   988  }
   989  
   990  func (cc *checkCloser) Close() error {
   991  	cc.closed = true
   992  	return nil
   993  }
   994  
   995  func (cc *checkCloser) Read(b []byte) (int, error) {
   996  	return len(b), nil
   997  }
   998  
   999  // Issue 23643: panic on body copy error
  1000  func TestReverseProxy_PanicBodyError(t *testing.T) {
  1001  	log.SetOutput(ioutil.Discard)
  1002  	defer log.SetOutput(os.Stderr)
  1003  	backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1004  		out := "this call was relayed by the reverse proxy"
  1005  		// Coerce a wrong content length to induce io.ErrUnexpectedEOF
  1006  		w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
  1007  		fmt.Fprintln(w, out)
  1008  	}))
  1009  	defer backendServer.Close()
  1010  
  1011  	rpURL, err := url.Parse(backendServer.URL)
  1012  	if err != nil {
  1013  		t.Fatal(err)
  1014  	}
  1015  
  1016  	rproxy := NewSingleHostReverseProxy(rpURL)
  1017  
  1018  	// Ensure that the handler panics when the body read encounters an
  1019  	// io.ErrUnexpectedEOF
  1020  	defer func() {
  1021  		err := recover()
  1022  		if err == nil {
  1023  			t.Fatal("handler should have panicked")
  1024  		}
  1025  		if err != http.ErrAbortHandler {
  1026  			t.Fatal("expected ErrAbortHandler, got", err)
  1027  		}
  1028  	}()
  1029  	req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
  1030  	rproxy.ServeHTTP(httptest.NewRecorder(), req)
  1031  }
  1032  
  1033  func TestSelectFlushInterval(t *testing.T) {
  1034  	tests := []struct {
  1035  		name string
  1036  		p    *ReverseProxy
  1037  		req  *http.Request
  1038  		res  *http.Response
  1039  		want time.Duration
  1040  	}{
  1041  		{
  1042  			name: "default",
  1043  			res:  &http.Response{},
  1044  			p:    &ReverseProxy{FlushInterval: 123},
  1045  			want: 123,
  1046  		},
  1047  		{
  1048  			name: "server-sent events overrides non-zero",
  1049  			res: &http.Response{
  1050  				Header: http.Header{
  1051  					"Content-Type": {"text/event-stream"},
  1052  				},
  1053  			},
  1054  			p:    &ReverseProxy{FlushInterval: 123},
  1055  			want: -1,
  1056  		},
  1057  		{
  1058  			name: "server-sent events overrides zero",
  1059  			res: &http.Response{
  1060  				Header: http.Header{
  1061  					"Content-Type": {"text/event-stream"},
  1062  				},
  1063  			},
  1064  			p:    &ReverseProxy{FlushInterval: 0},
  1065  			want: -1,
  1066  		},
  1067  	}
  1068  	for _, tt := range tests {
  1069  		t.Run(tt.name, func(t *testing.T) {
  1070  			got := tt.p.flushInterval(tt.req, tt.res)
  1071  			if got != tt.want {
  1072  				t.Errorf("flushLatency = %v; want %v", got, tt.want)
  1073  			}
  1074  		})
  1075  	}
  1076  }
  1077  
  1078  func TestReverseProxyWebSocket(t *testing.T) {
  1079  	backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1080  		if upgradeType(r.Header) != "websocket" {
  1081  			t.Error("unexpected backend request")
  1082  			http.Error(w, "unexpected request", 400)
  1083  			return
  1084  		}
  1085  		c, _, err := w.(http.Hijacker).Hijack()
  1086  		if err != nil {
  1087  			t.Error(err)
  1088  			return
  1089  		}
  1090  		defer c.Close()
  1091  		io.WriteString(c, "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n")
  1092  		bs := bufio.NewScanner(c)
  1093  		if !bs.Scan() {
  1094  			t.Errorf("backend failed to read line from client: %v", bs.Err())
  1095  			return
  1096  		}
  1097  		fmt.Fprintf(c, "backend got %q\n", bs.Text())
  1098  	}))
  1099  	defer backendServer.Close()
  1100  
  1101  	backURL, _ := url.Parse(backendServer.URL)
  1102  	rproxy := NewSingleHostReverseProxy(backURL)
  1103  	rproxy.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests
  1104  	rproxy.ModifyResponse = func(res *http.Response) error {
  1105  		res.Header.Add("X-Modified", "true")
  1106  		return nil
  1107  	}
  1108  
  1109  	handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
  1110  		rw.Header().Set("X-Header", "X-Value")
  1111  		rproxy.ServeHTTP(rw, req)
  1112  	})
  1113  
  1114  	frontendProxy := httptest.NewServer(handler)
  1115  	defer frontendProxy.Close()
  1116  
  1117  	req, _ := http.NewRequest("GET", frontendProxy.URL, nil)
  1118  	req.Header.Set("Connection", "Upgrade")
  1119  	req.Header.Set("Upgrade", "websocket")
  1120  
  1121  	c := frontendProxy.Client()
  1122  	res, err := c.Do(req)
  1123  	if err != nil {
  1124  		t.Fatal(err)
  1125  	}
  1126  	if res.StatusCode != 101 {
  1127  		t.Fatalf("status = %v; want 101", res.Status)
  1128  	}
  1129  
  1130  	got := res.Header.Get("X-Header")
  1131  	want := "X-Value"
  1132  	if got != want {
  1133  		t.Errorf("Header(XHeader) = %q; want %q", got, want)
  1134  	}
  1135  
  1136  	if upgradeType(res.Header) != "websocket" {
  1137  		t.Fatalf("not websocket upgrade; got %#v", res.Header)
  1138  	}
  1139  	rwc, ok := res.Body.(io.ReadWriteCloser)
  1140  	if !ok {
  1141  		t.Fatalf("response body is of type %T; does not implement ReadWriteCloser", res.Body)
  1142  	}
  1143  	defer rwc.Close()
  1144  
  1145  	if got, want := res.Header.Get("X-Modified"), "true"; got != want {
  1146  		t.Errorf("response X-Modified header = %q; want %q", got, want)
  1147  	}
  1148  
  1149  	io.WriteString(rwc, "Hello\n")
  1150  	bs := bufio.NewScanner(rwc)
  1151  	if !bs.Scan() {
  1152  		t.Fatalf("Scan: %v", bs.Err())
  1153  	}
  1154  	got = bs.Text()
  1155  	want = `backend got "Hello"`
  1156  	if got != want {
  1157  		t.Errorf("got %#q, want %#q", got, want)
  1158  	}
  1159  }
  1160  
  1161  func TestUnannouncedTrailer(t *testing.T) {
  1162  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1163  		w.WriteHeader(http.StatusOK)
  1164  		w.(http.Flusher).Flush()
  1165  		w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value")
  1166  	}))
  1167  	defer backend.Close()
  1168  	backendURL, err := url.Parse(backend.URL)
  1169  	if err != nil {
  1170  		t.Fatal(err)
  1171  	}
  1172  	proxyHandler := NewSingleHostReverseProxy(backendURL)
  1173  	proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests
  1174  	frontend := httptest.NewServer(proxyHandler)
  1175  	defer frontend.Close()
  1176  	frontendClient := frontend.Client()
  1177  
  1178  	res, err := frontendClient.Get(frontend.URL)
  1179  	if err != nil {
  1180  		t.Fatalf("Get: %v", err)
  1181  	}
  1182  
  1183  	ioutil.ReadAll(res.Body)
  1184  
  1185  	if g, w := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != w {
  1186  		t.Errorf("Trailer(X-Unannounced-Trailer) = %q; want %q", g, w)
  1187  	}
  1188  
  1189  }
  1190  
  1191  func TestSingleJoinSlash(t *testing.T) {
  1192  	tests := []struct {
  1193  		slasha   string
  1194  		slashb   string
  1195  		expected string
  1196  	}{
  1197  		{"https://www.google.com/", "/favicon.ico", "https://www.google.com/favicon.ico"},
  1198  		{"https://www.google.com", "/favicon.ico", "https://www.google.com/favicon.ico"},
  1199  		{"https://www.google.com", "favicon.ico", "https://www.google.com/favicon.ico"},
  1200  		{"https://www.google.com", "", "https://www.google.com/"},
  1201  		{"", "favicon.ico", "/favicon.ico"},
  1202  	}
  1203  	for _, tt := range tests {
  1204  		if got := singleJoiningSlash(tt.slasha, tt.slashb); got != tt.expected {
  1205  			t.Errorf("singleJoiningSlash(%s,%s) want %s got %s",
  1206  				tt.slasha,
  1207  				tt.slashb,
  1208  				tt.expected,
  1209  				got)
  1210  		}
  1211  	}
  1212  }
  1213  

View as plain text