...
Run Format

Source file src/net/http/httputil/reverseproxy_test.go

Documentation: net/http/httputil

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Reverse proxy tests.
     6  
     7  package httputil
     8  
     9  import (
    10  	"bufio"
    11  	"bytes"
    12  	"errors"
    13  	"fmt"
    14  	"io"
    15  	"io/ioutil"
    16  	"log"
    17  	"net/http"
    18  	"net/http/httptest"
    19  	"net/url"
    20  	"os"
    21  	"reflect"
    22  	"strconv"
    23  	"strings"
    24  	"sync"
    25  	"testing"
    26  	"time"
    27  )
    28  
    29  const fakeHopHeader = "X-Fake-Hop-Header-For-Test"
    30  
    31  func init() {
    32  	inOurTests = true
    33  	hopHeaders = append(hopHeaders, fakeHopHeader)
    34  }
    35  
    36  func TestReverseProxy(t *testing.T) {
    37  	const backendResponse = "I am the backend"
    38  	const backendStatus = 404
    39  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    40  		if r.Method == "GET" && r.FormValue("mode") == "hangup" {
    41  			c, _, _ := w.(http.Hijacker).Hijack()
    42  			c.Close()
    43  			return
    44  		}
    45  		if len(r.TransferEncoding) > 0 {
    46  			t.Errorf("backend got unexpected TransferEncoding: %v", r.TransferEncoding)
    47  		}
    48  		if r.Header.Get("X-Forwarded-For") == "" {
    49  			t.Errorf("didn't get X-Forwarded-For header")
    50  		}
    51  		if c := r.Header.Get("Connection"); c != "" {
    52  			t.Errorf("handler got Connection header value %q", c)
    53  		}
    54  		if c := r.Header.Get("Te"); c != "trailers" {
    55  			t.Errorf("handler got Te header value %q; want 'trailers'", c)
    56  		}
    57  		if c := r.Header.Get("Upgrade"); c != "" {
    58  			t.Errorf("handler got Upgrade header value %q", c)
    59  		}
    60  		if c := r.Header.Get("Proxy-Connection"); c != "" {
    61  			t.Errorf("handler got Proxy-Connection header value %q", c)
    62  		}
    63  		if g, e := r.Host, "some-name"; g != e {
    64  			t.Errorf("backend got Host header %q, want %q", g, e)
    65  		}
    66  		w.Header().Set("Trailers", "not a special header field name")
    67  		w.Header().Set("Trailer", "X-Trailer")
    68  		w.Header().Set("X-Foo", "bar")
    69  		w.Header().Set("Upgrade", "foo")
    70  		w.Header().Set(fakeHopHeader, "foo")
    71  		w.Header().Add("X-Multi-Value", "foo")
    72  		w.Header().Add("X-Multi-Value", "bar")
    73  		http.SetCookie(w, &http.Cookie{Name: "flavor", Value: "chocolateChip"})
    74  		w.WriteHeader(backendStatus)
    75  		w.Write([]byte(backendResponse))
    76  		w.Header().Set("X-Trailer", "trailer_value")
    77  		w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value")
    78  	}))
    79  	defer backend.Close()
    80  	backendURL, err := url.Parse(backend.URL)
    81  	if err != nil {
    82  		t.Fatal(err)
    83  	}
    84  	proxyHandler := NewSingleHostReverseProxy(backendURL)
    85  	proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests
    86  	frontend := httptest.NewServer(proxyHandler)
    87  	defer frontend.Close()
    88  	frontendClient := frontend.Client()
    89  
    90  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
    91  	getReq.Host = "some-name"
    92  	getReq.Header.Set("Connection", "close")
    93  	getReq.Header.Set("Te", "trailers")
    94  	getReq.Header.Set("Proxy-Connection", "should be deleted")
    95  	getReq.Header.Set("Upgrade", "foo")
    96  	getReq.Close = true
    97  	res, err := frontendClient.Do(getReq)
    98  	if err != nil {
    99  		t.Fatalf("Get: %v", err)
   100  	}
   101  	if g, e := res.StatusCode, backendStatus; g != e {
   102  		t.Errorf("got res.StatusCode %d; expected %d", g, e)
   103  	}
   104  	if g, e := res.Header.Get("X-Foo"), "bar"; g != e {
   105  		t.Errorf("got X-Foo %q; expected %q", g, e)
   106  	}
   107  	if c := res.Header.Get(fakeHopHeader); c != "" {
   108  		t.Errorf("got %s header value %q", fakeHopHeader, c)
   109  	}
   110  	if g, e := res.Header.Get("Trailers"), "not a special header field name"; g != e {
   111  		t.Errorf("header Trailers = %q; want %q", g, e)
   112  	}
   113  	if g, e := len(res.Header["X-Multi-Value"]), 2; g != e {
   114  		t.Errorf("got %d X-Multi-Value header values; expected %d", g, e)
   115  	}
   116  	if g, e := len(res.Header["Set-Cookie"]), 1; g != e {
   117  		t.Fatalf("got %d SetCookies, want %d", g, e)
   118  	}
   119  	if g, e := res.Trailer, (http.Header{"X-Trailer": nil}); !reflect.DeepEqual(g, e) {
   120  		t.Errorf("before reading body, Trailer = %#v; want %#v", g, e)
   121  	}
   122  	if cookie := res.Cookies()[0]; cookie.Name != "flavor" {
   123  		t.Errorf("unexpected cookie %q", cookie.Name)
   124  	}
   125  	bodyBytes, _ := ioutil.ReadAll(res.Body)
   126  	if g, e := string(bodyBytes), backendResponse; g != e {
   127  		t.Errorf("got body %q; expected %q", g, e)
   128  	}
   129  	if g, e := res.Trailer.Get("X-Trailer"), "trailer_value"; g != e {
   130  		t.Errorf("Trailer(X-Trailer) = %q ; want %q", g, e)
   131  	}
   132  	if g, e := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != e {
   133  		t.Errorf("Trailer(X-Unannounced-Trailer) = %q ; want %q", g, e)
   134  	}
   135  
   136  	// Test that a backend failing to be reached or one which doesn't return
   137  	// a response results in a StatusBadGateway.
   138  	getReq, _ = http.NewRequest("GET", frontend.URL+"/?mode=hangup", nil)
   139  	getReq.Close = true
   140  	res, err = frontendClient.Do(getReq)
   141  	if err != nil {
   142  		t.Fatal(err)
   143  	}
   144  	res.Body.Close()
   145  	if res.StatusCode != http.StatusBadGateway {
   146  		t.Errorf("request to bad proxy = %v; want 502 StatusBadGateway", res.Status)
   147  	}
   148  
   149  }
   150  
   151  // Issue 16875: remove any proxied headers mentioned in the "Connection"
   152  // header value.
   153  func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) {
   154  	const fakeConnectionToken = "X-Fake-Connection-Token"
   155  	const backendResponse = "I am the backend"
   156  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   157  		if c := r.Header.Get(fakeConnectionToken); c != "" {
   158  			t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
   159  		}
   160  		if c := r.Header.Get("Upgrade"); c != "" {
   161  			t.Errorf("handler got header %q = %q; want empty", "Upgrade", c)
   162  		}
   163  		w.Header().Set("Connection", "Upgrade, "+fakeConnectionToken)
   164  		w.Header().Set("Upgrade", "should be deleted")
   165  		w.Header().Set(fakeConnectionToken, "should be deleted")
   166  		io.WriteString(w, backendResponse)
   167  	}))
   168  	defer backend.Close()
   169  	backendURL, err := url.Parse(backend.URL)
   170  	if err != nil {
   171  		t.Fatal(err)
   172  	}
   173  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   174  	frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   175  		proxyHandler.ServeHTTP(w, r)
   176  		if c := r.Header.Get("Upgrade"); c != "original value" {
   177  			t.Errorf("handler modified header %q = %q; want %q", "Upgrade", c, "original value")
   178  		}
   179  	}))
   180  	defer frontend.Close()
   181  
   182  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   183  	getReq.Header.Set("Connection", "Upgrade, "+fakeConnectionToken)
   184  	getReq.Header.Set("Upgrade", "original value")
   185  	getReq.Header.Set(fakeConnectionToken, "should be deleted")
   186  	res, err := frontend.Client().Do(getReq)
   187  	if err != nil {
   188  		t.Fatalf("Get: %v", err)
   189  	}
   190  	defer res.Body.Close()
   191  	bodyBytes, err := ioutil.ReadAll(res.Body)
   192  	if err != nil {
   193  		t.Fatalf("reading body: %v", err)
   194  	}
   195  	if got, want := string(bodyBytes), backendResponse; got != want {
   196  		t.Errorf("got body %q; want %q", got, want)
   197  	}
   198  	if c := res.Header.Get("Upgrade"); c != "" {
   199  		t.Errorf("handler got header %q = %q; want empty", "Upgrade", c)
   200  	}
   201  	if c := res.Header.Get(fakeConnectionToken); c != "" {
   202  		t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
   203  	}
   204  }
   205  
   206  func TestXForwardedFor(t *testing.T) {
   207  	const prevForwardedFor = "client ip"
   208  	const backendResponse = "I am the backend"
   209  	const backendStatus = 404
   210  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   211  		if r.Header.Get("X-Forwarded-For") == "" {
   212  			t.Errorf("didn't get X-Forwarded-For header")
   213  		}
   214  		if !strings.Contains(r.Header.Get("X-Forwarded-For"), prevForwardedFor) {
   215  			t.Errorf("X-Forwarded-For didn't contain prior data")
   216  		}
   217  		w.WriteHeader(backendStatus)
   218  		w.Write([]byte(backendResponse))
   219  	}))
   220  	defer backend.Close()
   221  	backendURL, err := url.Parse(backend.URL)
   222  	if err != nil {
   223  		t.Fatal(err)
   224  	}
   225  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   226  	frontend := httptest.NewServer(proxyHandler)
   227  	defer frontend.Close()
   228  
   229  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   230  	getReq.Host = "some-name"
   231  	getReq.Header.Set("Connection", "close")
   232  	getReq.Header.Set("X-Forwarded-For", prevForwardedFor)
   233  	getReq.Close = true
   234  	res, err := frontend.Client().Do(getReq)
   235  	if err != nil {
   236  		t.Fatalf("Get: %v", err)
   237  	}
   238  	if g, e := res.StatusCode, backendStatus; g != e {
   239  		t.Errorf("got res.StatusCode %d; expected %d", g, e)
   240  	}
   241  	bodyBytes, _ := ioutil.ReadAll(res.Body)
   242  	if g, e := string(bodyBytes), backendResponse; g != e {
   243  		t.Errorf("got body %q; expected %q", g, e)
   244  	}
   245  }
   246  
   247  var proxyQueryTests = []struct {
   248  	baseSuffix string // suffix to add to backend URL
   249  	reqSuffix  string // suffix to add to frontend's request URL
   250  	want       string // what backend should see for final request URL (without ?)
   251  }{
   252  	{"", "", ""},
   253  	{"?sta=tic", "?us=er", "sta=tic&us=er"},
   254  	{"", "?us=er", "us=er"},
   255  	{"?sta=tic", "", "sta=tic"},
   256  }
   257  
   258  func TestReverseProxyQuery(t *testing.T) {
   259  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   260  		w.Header().Set("X-Got-Query", r.URL.RawQuery)
   261  		w.Write([]byte("hi"))
   262  	}))
   263  	defer backend.Close()
   264  
   265  	for i, tt := range proxyQueryTests {
   266  		backendURL, err := url.Parse(backend.URL + tt.baseSuffix)
   267  		if err != nil {
   268  			t.Fatal(err)
   269  		}
   270  		frontend := httptest.NewServer(NewSingleHostReverseProxy(backendURL))
   271  		req, _ := http.NewRequest("GET", frontend.URL+tt.reqSuffix, nil)
   272  		req.Close = true
   273  		res, err := frontend.Client().Do(req)
   274  		if err != nil {
   275  			t.Fatalf("%d. Get: %v", i, err)
   276  		}
   277  		if g, e := res.Header.Get("X-Got-Query"), tt.want; g != e {
   278  			t.Errorf("%d. got query %q; expected %q", i, g, e)
   279  		}
   280  		res.Body.Close()
   281  		frontend.Close()
   282  	}
   283  }
   284  
   285  func TestReverseProxyFlushInterval(t *testing.T) {
   286  	const expected = "hi"
   287  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   288  		w.Write([]byte(expected))
   289  	}))
   290  	defer backend.Close()
   291  
   292  	backendURL, err := url.Parse(backend.URL)
   293  	if err != nil {
   294  		t.Fatal(err)
   295  	}
   296  
   297  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   298  	proxyHandler.FlushInterval = time.Microsecond
   299  
   300  	done := make(chan bool)
   301  	onExitFlushLoop = func() { done <- true }
   302  	defer func() { onExitFlushLoop = nil }()
   303  
   304  	frontend := httptest.NewServer(proxyHandler)
   305  	defer frontend.Close()
   306  
   307  	req, _ := http.NewRequest("GET", frontend.URL, nil)
   308  	req.Close = true
   309  	res, err := frontend.Client().Do(req)
   310  	if err != nil {
   311  		t.Fatalf("Get: %v", err)
   312  	}
   313  	defer res.Body.Close()
   314  	if bodyBytes, _ := ioutil.ReadAll(res.Body); string(bodyBytes) != expected {
   315  		t.Errorf("got body %q; expected %q", bodyBytes, expected)
   316  	}
   317  
   318  	select {
   319  	case <-done:
   320  		// OK
   321  	case <-time.After(5 * time.Second):
   322  		t.Error("maxLatencyWriter flushLoop() never exited")
   323  	}
   324  }
   325  
   326  func TestReverseProxyCancelation(t *testing.T) {
   327  	const backendResponse = "I am the backend"
   328  
   329  	reqInFlight := make(chan struct{})
   330  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   331  		close(reqInFlight) // cause the client to cancel its request
   332  
   333  		select {
   334  		case <-time.After(10 * time.Second):
   335  			// Note: this should only happen in broken implementations, and the
   336  			// closenotify case should be instantaneous.
   337  			t.Error("Handler never saw CloseNotify")
   338  			return
   339  		case <-w.(http.CloseNotifier).CloseNotify():
   340  		}
   341  
   342  		w.WriteHeader(http.StatusOK)
   343  		w.Write([]byte(backendResponse))
   344  	}))
   345  
   346  	defer backend.Close()
   347  
   348  	backend.Config.ErrorLog = log.New(ioutil.Discard, "", 0)
   349  
   350  	backendURL, err := url.Parse(backend.URL)
   351  	if err != nil {
   352  		t.Fatal(err)
   353  	}
   354  
   355  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   356  
   357  	// Discards errors of the form:
   358  	// http: proxy error: read tcp 127.0.0.1:44643: use of closed network connection
   359  	proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0)
   360  
   361  	frontend := httptest.NewServer(proxyHandler)
   362  	defer frontend.Close()
   363  	frontendClient := frontend.Client()
   364  
   365  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   366  	go func() {
   367  		<-reqInFlight
   368  		frontendClient.Transport.(*http.Transport).CancelRequest(getReq)
   369  	}()
   370  	res, err := frontendClient.Do(getReq)
   371  	if res != nil {
   372  		t.Errorf("got response %v; want nil", res.Status)
   373  	}
   374  	if err == nil {
   375  		// This should be an error like:
   376  		// Get http://127.0.0.1:58079: read tcp 127.0.0.1:58079:
   377  		//    use of closed network connection
   378  		t.Error("Server.Client().Do() returned nil error; want non-nil error")
   379  	}
   380  }
   381  
   382  func req(t *testing.T, v string) *http.Request {
   383  	req, err := http.ReadRequest(bufio.NewReader(strings.NewReader(v)))
   384  	if err != nil {
   385  		t.Fatal(err)
   386  	}
   387  	return req
   388  }
   389  
   390  // Issue 12344
   391  func TestNilBody(t *testing.T) {
   392  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   393  		w.Write([]byte("hi"))
   394  	}))
   395  	defer backend.Close()
   396  
   397  	frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
   398  		backURL, _ := url.Parse(backend.URL)
   399  		rp := NewSingleHostReverseProxy(backURL)
   400  		r := req(t, "GET / HTTP/1.0\r\n\r\n")
   401  		r.Body = nil // this accidentally worked in Go 1.4 and below, so keep it working
   402  		rp.ServeHTTP(w, r)
   403  	}))
   404  	defer frontend.Close()
   405  
   406  	res, err := http.Get(frontend.URL)
   407  	if err != nil {
   408  		t.Fatal(err)
   409  	}
   410  	defer res.Body.Close()
   411  	slurp, err := ioutil.ReadAll(res.Body)
   412  	if err != nil {
   413  		t.Fatal(err)
   414  	}
   415  	if string(slurp) != "hi" {
   416  		t.Errorf("Got %q; want %q", slurp, "hi")
   417  	}
   418  }
   419  
   420  // Issue 15524
   421  func TestUserAgentHeader(t *testing.T) {
   422  	const explicitUA = "explicit UA"
   423  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   424  		if r.URL.Path == "/noua" {
   425  			if c := r.Header.Get("User-Agent"); c != "" {
   426  				t.Errorf("handler got non-empty User-Agent header %q", c)
   427  			}
   428  			return
   429  		}
   430  		if c := r.Header.Get("User-Agent"); c != explicitUA {
   431  			t.Errorf("handler got unexpected User-Agent header %q", c)
   432  		}
   433  	}))
   434  	defer backend.Close()
   435  	backendURL, err := url.Parse(backend.URL)
   436  	if err != nil {
   437  		t.Fatal(err)
   438  	}
   439  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   440  	proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests
   441  	frontend := httptest.NewServer(proxyHandler)
   442  	defer frontend.Close()
   443  	frontendClient := frontend.Client()
   444  
   445  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   446  	getReq.Header.Set("User-Agent", explicitUA)
   447  	getReq.Close = true
   448  	res, err := frontendClient.Do(getReq)
   449  	if err != nil {
   450  		t.Fatalf("Get: %v", err)
   451  	}
   452  	res.Body.Close()
   453  
   454  	getReq, _ = http.NewRequest("GET", frontend.URL+"/noua", nil)
   455  	getReq.Header.Set("User-Agent", "")
   456  	getReq.Close = true
   457  	res, err = frontendClient.Do(getReq)
   458  	if err != nil {
   459  		t.Fatalf("Get: %v", err)
   460  	}
   461  	res.Body.Close()
   462  }
   463  
   464  type bufferPool struct {
   465  	get func() []byte
   466  	put func([]byte)
   467  }
   468  
   469  func (bp bufferPool) Get() []byte  { return bp.get() }
   470  func (bp bufferPool) Put(v []byte) { bp.put(v) }
   471  
   472  func TestReverseProxyGetPutBuffer(t *testing.T) {
   473  	const msg = "hi"
   474  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   475  		io.WriteString(w, msg)
   476  	}))
   477  	defer backend.Close()
   478  
   479  	backendURL, err := url.Parse(backend.URL)
   480  	if err != nil {
   481  		t.Fatal(err)
   482  	}
   483  
   484  	var (
   485  		mu  sync.Mutex
   486  		log []string
   487  	)
   488  	addLog := func(event string) {
   489  		mu.Lock()
   490  		defer mu.Unlock()
   491  		log = append(log, event)
   492  	}
   493  	rp := NewSingleHostReverseProxy(backendURL)
   494  	const size = 1234
   495  	rp.BufferPool = bufferPool{
   496  		get: func() []byte {
   497  			addLog("getBuf")
   498  			return make([]byte, size)
   499  		},
   500  		put: func(p []byte) {
   501  			addLog("putBuf-" + strconv.Itoa(len(p)))
   502  		},
   503  	}
   504  	frontend := httptest.NewServer(rp)
   505  	defer frontend.Close()
   506  
   507  	req, _ := http.NewRequest("GET", frontend.URL, nil)
   508  	req.Close = true
   509  	res, err := frontend.Client().Do(req)
   510  	if err != nil {
   511  		t.Fatalf("Get: %v", err)
   512  	}
   513  	slurp, err := ioutil.ReadAll(res.Body)
   514  	res.Body.Close()
   515  	if err != nil {
   516  		t.Fatalf("reading body: %v", err)
   517  	}
   518  	if string(slurp) != msg {
   519  		t.Errorf("msg = %q; want %q", slurp, msg)
   520  	}
   521  	wantLog := []string{"getBuf", "putBuf-" + strconv.Itoa(size)}
   522  	mu.Lock()
   523  	defer mu.Unlock()
   524  	if !reflect.DeepEqual(log, wantLog) {
   525  		t.Errorf("Log events = %q; want %q", log, wantLog)
   526  	}
   527  }
   528  
   529  func TestReverseProxy_Post(t *testing.T) {
   530  	const backendResponse = "I am the backend"
   531  	const backendStatus = 200
   532  	var requestBody = bytes.Repeat([]byte("a"), 1<<20)
   533  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   534  		slurp, err := ioutil.ReadAll(r.Body)
   535  		if err != nil {
   536  			t.Errorf("Backend body read = %v", err)
   537  		}
   538  		if len(slurp) != len(requestBody) {
   539  			t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody))
   540  		}
   541  		if !bytes.Equal(slurp, requestBody) {
   542  			t.Error("Backend read wrong request body.") // 1MB; omitting details
   543  		}
   544  		w.Write([]byte(backendResponse))
   545  	}))
   546  	defer backend.Close()
   547  	backendURL, err := url.Parse(backend.URL)
   548  	if err != nil {
   549  		t.Fatal(err)
   550  	}
   551  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   552  	frontend := httptest.NewServer(proxyHandler)
   553  	defer frontend.Close()
   554  
   555  	postReq, _ := http.NewRequest("POST", frontend.URL, bytes.NewReader(requestBody))
   556  	res, err := frontend.Client().Do(postReq)
   557  	if err != nil {
   558  		t.Fatalf("Do: %v", err)
   559  	}
   560  	if g, e := res.StatusCode, backendStatus; g != e {
   561  		t.Errorf("got res.StatusCode %d; expected %d", g, e)
   562  	}
   563  	bodyBytes, _ := ioutil.ReadAll(res.Body)
   564  	if g, e := string(bodyBytes), backendResponse; g != e {
   565  		t.Errorf("got body %q; expected %q", g, e)
   566  	}
   567  }
   568  
   569  type RoundTripperFunc func(*http.Request) (*http.Response, error)
   570  
   571  func (fn RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
   572  	return fn(req)
   573  }
   574  
   575  // Issue 16036: send a Request with a nil Body when possible
   576  func TestReverseProxy_NilBody(t *testing.T) {
   577  	backendURL, _ := url.Parse("http://fake.tld/")
   578  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   579  	proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests
   580  	proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
   581  		if req.Body != nil {
   582  			t.Error("Body != nil; want a nil Body")
   583  		}
   584  		return nil, errors.New("done testing the interesting part; so force a 502 Gateway error")
   585  	})
   586  	frontend := httptest.NewServer(proxyHandler)
   587  	defer frontend.Close()
   588  
   589  	res, err := frontend.Client().Get(frontend.URL)
   590  	if err != nil {
   591  		t.Fatal(err)
   592  	}
   593  	defer res.Body.Close()
   594  	if res.StatusCode != 502 {
   595  		t.Errorf("status code = %v; want 502 (Gateway Error)", res.Status)
   596  	}
   597  }
   598  
   599  // Issue 14237. Test ModifyResponse and that an error from it
   600  // causes the proxy to return StatusBadGateway, or StatusOK otherwise.
   601  func TestReverseProxyModifyResponse(t *testing.T) {
   602  	backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   603  		w.Header().Add("X-Hit-Mod", fmt.Sprintf("%v", r.URL.Path == "/mod"))
   604  	}))
   605  	defer backendServer.Close()
   606  
   607  	rpURL, _ := url.Parse(backendServer.URL)
   608  	rproxy := NewSingleHostReverseProxy(rpURL)
   609  	rproxy.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests
   610  	rproxy.ModifyResponse = func(resp *http.Response) error {
   611  		if resp.Header.Get("X-Hit-Mod") != "true" {
   612  			return fmt.Errorf("tried to by-pass proxy")
   613  		}
   614  		return nil
   615  	}
   616  
   617  	frontendProxy := httptest.NewServer(rproxy)
   618  	defer frontendProxy.Close()
   619  
   620  	tests := []struct {
   621  		url      string
   622  		wantCode int
   623  	}{
   624  		{frontendProxy.URL + "/mod", http.StatusOK},
   625  		{frontendProxy.URL + "/schedule", http.StatusBadGateway},
   626  	}
   627  
   628  	for i, tt := range tests {
   629  		resp, err := http.Get(tt.url)
   630  		if err != nil {
   631  			t.Fatalf("failed to reach proxy: %v", err)
   632  		}
   633  		if g, e := resp.StatusCode, tt.wantCode; g != e {
   634  			t.Errorf("#%d: got res.StatusCode %d; expected %d", i, g, e)
   635  		}
   636  		resp.Body.Close()
   637  	}
   638  }
   639  
   640  type failingRoundTripper struct{}
   641  
   642  func (failingRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
   643  	return nil, errors.New("some error")
   644  }
   645  
   646  type staticResponseRoundTripper struct{ res *http.Response }
   647  
   648  func (rt staticResponseRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
   649  	return rt.res, nil
   650  }
   651  
   652  func TestReverseProxyErrorHandler(t *testing.T) {
   653  	tests := []struct {
   654  		name           string
   655  		wantCode       int
   656  		errorHandler   func(http.ResponseWriter, *http.Request, error)
   657  		transport      http.RoundTripper // defaults to failingRoundTripper
   658  		modifyResponse func(*http.Response) error
   659  	}{
   660  		{
   661  			name:     "default",
   662  			wantCode: http.StatusBadGateway,
   663  		},
   664  		{
   665  			name:         "errorhandler",
   666  			wantCode:     http.StatusTeapot,
   667  			errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
   668  		},
   669  		{
   670  			name: "modifyresponse_noerr",
   671  			transport: staticResponseRoundTripper{
   672  				&http.Response{StatusCode: 345, Body: http.NoBody},
   673  			},
   674  			modifyResponse: func(res *http.Response) error {
   675  				res.StatusCode++
   676  				return nil
   677  			},
   678  			errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
   679  			wantCode:     346,
   680  		},
   681  		{
   682  			name: "modifyresponse_err",
   683  			transport: staticResponseRoundTripper{
   684  				&http.Response{StatusCode: 345, Body: http.NoBody},
   685  			},
   686  			modifyResponse: func(res *http.Response) error {
   687  				res.StatusCode++
   688  				return errors.New("some error to trigger errorHandler")
   689  			},
   690  			errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
   691  			wantCode:     http.StatusTeapot,
   692  		},
   693  	}
   694  
   695  	for _, tt := range tests {
   696  		t.Run(tt.name, func(t *testing.T) {
   697  			target := &url.URL{
   698  				Scheme: "http",
   699  				Host:   "dummy.tld",
   700  				Path:   "/",
   701  			}
   702  			rproxy := NewSingleHostReverseProxy(target)
   703  			rproxy.Transport = tt.transport
   704  			rproxy.ModifyResponse = tt.modifyResponse
   705  			if rproxy.Transport == nil {
   706  				rproxy.Transport = failingRoundTripper{}
   707  			}
   708  			rproxy.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests
   709  			if tt.errorHandler != nil {
   710  				rproxy.ErrorHandler = tt.errorHandler
   711  			}
   712  			frontendProxy := httptest.NewServer(rproxy)
   713  			defer frontendProxy.Close()
   714  
   715  			resp, err := http.Get(frontendProxy.URL + "/test")
   716  			if err != nil {
   717  				t.Fatalf("failed to reach proxy: %v", err)
   718  			}
   719  			if g, e := resp.StatusCode, tt.wantCode; g != e {
   720  				t.Errorf("got res.StatusCode %d; expected %d", g, e)
   721  			}
   722  			resp.Body.Close()
   723  		})
   724  	}
   725  }
   726  
   727  // Issue 16659: log errors from short read
   728  func TestReverseProxy_CopyBuffer(t *testing.T) {
   729  	backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   730  		out := "this call was relayed by the reverse proxy"
   731  		// Coerce a wrong content length to induce io.UnexpectedEOF
   732  		w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
   733  		fmt.Fprintln(w, out)
   734  	}))
   735  	defer backendServer.Close()
   736  
   737  	rpURL, err := url.Parse(backendServer.URL)
   738  	if err != nil {
   739  		t.Fatal(err)
   740  	}
   741  
   742  	var proxyLog bytes.Buffer
   743  	rproxy := NewSingleHostReverseProxy(rpURL)
   744  	rproxy.ErrorLog = log.New(&proxyLog, "", log.Lshortfile)
   745  	donec := make(chan bool, 1)
   746  	frontendProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   747  		defer func() { donec <- true }()
   748  		rproxy.ServeHTTP(w, r)
   749  	}))
   750  	defer frontendProxy.Close()
   751  
   752  	if _, err = frontendProxy.Client().Get(frontendProxy.URL); err == nil {
   753  		t.Fatalf("want non-nil error")
   754  	}
   755  	// The race detector complains about the proxyLog usage in logf in copyBuffer
   756  	// and our usage below with proxyLog.Bytes() so we're explicitly using a
   757  	// channel to ensure that the ReverseProxy's ServeHTTP is done before we
   758  	// continue after Get.
   759  	<-donec
   760  
   761  	expected := []string{
   762  		"EOF",
   763  		"read",
   764  	}
   765  	for _, phrase := range expected {
   766  		if !bytes.Contains(proxyLog.Bytes(), []byte(phrase)) {
   767  			t.Errorf("expected log to contain phrase %q", phrase)
   768  		}
   769  	}
   770  }
   771  
   772  type staticTransport struct {
   773  	res *http.Response
   774  }
   775  
   776  func (t *staticTransport) RoundTrip(r *http.Request) (*http.Response, error) {
   777  	return t.res, nil
   778  }
   779  
   780  func BenchmarkServeHTTP(b *testing.B) {
   781  	res := &http.Response{
   782  		StatusCode: 200,
   783  		Body:       ioutil.NopCloser(strings.NewReader("")),
   784  	}
   785  	proxy := &ReverseProxy{
   786  		Director:  func(*http.Request) {},
   787  		Transport: &staticTransport{res},
   788  	}
   789  
   790  	w := httptest.NewRecorder()
   791  	r := httptest.NewRequest("GET", "/", nil)
   792  
   793  	b.ReportAllocs()
   794  	for i := 0; i < b.N; i++ {
   795  		proxy.ServeHTTP(w, r)
   796  	}
   797  }
   798  
   799  func TestServeHTTPDeepCopy(t *testing.T) {
   800  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   801  		w.Write([]byte("Hello Gopher!"))
   802  	}))
   803  	defer backend.Close()
   804  	backendURL, err := url.Parse(backend.URL)
   805  	if err != nil {
   806  		t.Fatal(err)
   807  	}
   808  
   809  	type result struct {
   810  		before, after string
   811  	}
   812  
   813  	resultChan := make(chan result, 1)
   814  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   815  	frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   816  		before := r.URL.String()
   817  		proxyHandler.ServeHTTP(w, r)
   818  		after := r.URL.String()
   819  		resultChan <- result{before: before, after: after}
   820  	}))
   821  	defer frontend.Close()
   822  
   823  	want := result{before: "/", after: "/"}
   824  
   825  	res, err := frontend.Client().Get(frontend.URL)
   826  	if err != nil {
   827  		t.Fatalf("Do: %v", err)
   828  	}
   829  	res.Body.Close()
   830  
   831  	got := <-resultChan
   832  	if got != want {
   833  		t.Errorf("got = %+v; want = %+v", got, want)
   834  	}
   835  }
   836  
   837  // Issue 18327: verify we always do a deep copy of the Request.Header map
   838  // before any mutations.
   839  func TestClonesRequestHeaders(t *testing.T) {
   840  	log.SetOutput(ioutil.Discard)
   841  	defer log.SetOutput(os.Stderr)
   842  	req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
   843  	req.RemoteAddr = "1.2.3.4:56789"
   844  	rp := &ReverseProxy{
   845  		Director: func(req *http.Request) {
   846  			req.Header.Set("From-Director", "1")
   847  		},
   848  		Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
   849  			if v := req.Header.Get("From-Director"); v != "1" {
   850  				t.Errorf("From-Directory value = %q; want 1", v)
   851  			}
   852  			return nil, io.EOF
   853  		}),
   854  	}
   855  	rp.ServeHTTP(httptest.NewRecorder(), req)
   856  
   857  	if req.Header.Get("From-Director") == "1" {
   858  		t.Error("Director header mutation modified caller's request")
   859  	}
   860  	if req.Header.Get("X-Forwarded-For") != "" {
   861  		t.Error("X-Forward-For header mutation modified caller's request")
   862  	}
   863  
   864  }
   865  
   866  type roundTripperFunc func(req *http.Request) (*http.Response, error)
   867  
   868  func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
   869  	return fn(req)
   870  }
   871  
   872  func TestModifyResponseClosesBody(t *testing.T) {
   873  	req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
   874  	req.RemoteAddr = "1.2.3.4:56789"
   875  	closeCheck := new(checkCloser)
   876  	logBuf := new(bytes.Buffer)
   877  	outErr := errors.New("ModifyResponse error")
   878  	rp := &ReverseProxy{
   879  		Director: func(req *http.Request) {},
   880  		Transport: &staticTransport{&http.Response{
   881  			StatusCode: 200,
   882  			Body:       closeCheck,
   883  		}},
   884  		ErrorLog: log.New(logBuf, "", 0),
   885  		ModifyResponse: func(*http.Response) error {
   886  			return outErr
   887  		},
   888  	}
   889  	rec := httptest.NewRecorder()
   890  	rp.ServeHTTP(rec, req)
   891  	res := rec.Result()
   892  	if g, e := res.StatusCode, http.StatusBadGateway; g != e {
   893  		t.Errorf("got res.StatusCode %d; expected %d", g, e)
   894  	}
   895  	if !closeCheck.closed {
   896  		t.Errorf("body should have been closed")
   897  	}
   898  	if g, e := logBuf.String(), outErr.Error(); !strings.Contains(g, e) {
   899  		t.Errorf("ErrorLog %q does not contain %q", g, e)
   900  	}
   901  }
   902  
   903  type checkCloser struct {
   904  	closed bool
   905  }
   906  
   907  func (cc *checkCloser) Close() error {
   908  	cc.closed = true
   909  	return nil
   910  }
   911  
   912  func (cc *checkCloser) Read(b []byte) (int, error) {
   913  	return len(b), nil
   914  }
   915  
   916  // Issue 23643: panic on body copy error
   917  func TestReverseProxy_PanicBodyError(t *testing.T) {
   918  	log.SetOutput(ioutil.Discard)
   919  	defer log.SetOutput(os.Stderr)
   920  	backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   921  		out := "this call was relayed by the reverse proxy"
   922  		// Coerce a wrong content length to induce io.ErrUnexpectedEOF
   923  		w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
   924  		fmt.Fprintln(w, out)
   925  	}))
   926  	defer backendServer.Close()
   927  
   928  	rpURL, err := url.Parse(backendServer.URL)
   929  	if err != nil {
   930  		t.Fatal(err)
   931  	}
   932  
   933  	rproxy := NewSingleHostReverseProxy(rpURL)
   934  
   935  	// Ensure that the handler panics when the body read encounters an
   936  	// io.ErrUnexpectedEOF
   937  	defer func() {
   938  		err := recover()
   939  		if err == nil {
   940  			t.Fatal("handler should have panicked")
   941  		}
   942  		if err != http.ErrAbortHandler {
   943  			t.Fatal("expected ErrAbortHandler, got", err)
   944  		}
   945  	}()
   946  	req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
   947  	rproxy.ServeHTTP(httptest.NewRecorder(), req)
   948  }
   949  

View as plain text