1 // Copyright 2011 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 // Reverse proxy tests. 6 7 package httputil 8 9 import ( 10 "bufio" 11 "bytes" 12 "context" 13 "errors" 14 "fmt" 15 "io" 16 "io/ioutil" 17 "log" 18 "net/http" 19 "net/http/httptest" 20 "net/url" 21 "os" 22 "reflect" 23 "sort" 24 "strconv" 25 "strings" 26 "sync" 27 "testing" 28 "time" 29 ) 30 31 const fakeHopHeader = "X-Fake-Hop-Header-For-Test" 32 33 func init() { 34 inOurTests = true 35 hopHeaders = append(hopHeaders, fakeHopHeader) 36 } 37 38 func TestReverseProxy(t *testing.T) { 39 const backendResponse = "I am the backend" 40 const backendStatus = 404 41 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 42 if r.Method == "GET" && r.FormValue("mode") == "hangup" { 43 c, _, _ := w.(http.Hijacker).Hijack() 44 c.Close() 45 return 46 } 47 if len(r.TransferEncoding) > 0 { 48 t.Errorf("backend got unexpected TransferEncoding: %v", r.TransferEncoding) 49 } 50 if r.Header.Get("X-Forwarded-For") == "" { 51 t.Errorf("didn't get X-Forwarded-For header") 52 } 53 if c := r.Header.Get("Connection"); c != "" { 54 t.Errorf("handler got Connection header value %q", c) 55 } 56 if c := r.Header.Get("Te"); c != "trailers" { 57 t.Errorf("handler got Te header value %q; want 'trailers'", c) 58 } 59 if c := r.Header.Get("Upgrade"); c != "" { 60 t.Errorf("handler got Upgrade header value %q", c) 61 } 62 if c := r.Header.Get("Proxy-Connection"); c != "" { 63 t.Errorf("handler got Proxy-Connection header value %q", c) 64 } 65 if g, e := r.Host, "some-name"; g != e { 66 t.Errorf("backend got Host header %q, want %q", g, e) 67 } 68 w.Header().Set("Trailers", "not a special header field name") 69 w.Header().Set("Trailer", "X-Trailer") 70 w.Header().Set("X-Foo", "bar") 71 w.Header().Set("Upgrade", "foo") 72 w.Header().Set(fakeHopHeader, "foo") 73 w.Header().Add("X-Multi-Value", "foo") 74 w.Header().Add("X-Multi-Value", "bar") 75 http.SetCookie(w, &http.Cookie{Name: "flavor", Value: "chocolateChip"}) 76 w.WriteHeader(backendStatus) 77 w.Write([]byte(backendResponse)) 78 w.Header().Set("X-Trailer", "trailer_value") 79 w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value") 80 })) 81 defer backend.Close() 82 backendURL, err := url.Parse(backend.URL) 83 if err != nil { 84 t.Fatal(err) 85 } 86 proxyHandler := NewSingleHostReverseProxy(backendURL) 87 proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests 88 frontend := httptest.NewServer(proxyHandler) 89 defer frontend.Close() 90 frontendClient := frontend.Client() 91 92 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 93 getReq.Host = "some-name" 94 getReq.Header.Set("Connection", "close") 95 getReq.Header.Set("Te", "trailers") 96 getReq.Header.Set("Proxy-Connection", "should be deleted") 97 getReq.Header.Set("Upgrade", "foo") 98 getReq.Close = true 99 res, err := frontendClient.Do(getReq) 100 if err != nil { 101 t.Fatalf("Get: %v", err) 102 } 103 if g, e := res.StatusCode, backendStatus; g != e { 104 t.Errorf("got res.StatusCode %d; expected %d", g, e) 105 } 106 if g, e := res.Header.Get("X-Foo"), "bar"; g != e { 107 t.Errorf("got X-Foo %q; expected %q", g, e) 108 } 109 if c := res.Header.Get(fakeHopHeader); c != "" { 110 t.Errorf("got %s header value %q", fakeHopHeader, c) 111 } 112 if g, e := res.Header.Get("Trailers"), "not a special header field name"; g != e { 113 t.Errorf("header Trailers = %q; want %q", g, e) 114 } 115 if g, e := len(res.Header["X-Multi-Value"]), 2; g != e { 116 t.Errorf("got %d X-Multi-Value header values; expected %d", g, e) 117 } 118 if g, e := len(res.Header["Set-Cookie"]), 1; g != e { 119 t.Fatalf("got %d SetCookies, want %d", g, e) 120 } 121 if g, e := res.Trailer, (http.Header{"X-Trailer": nil}); !reflect.DeepEqual(g, e) { 122 t.Errorf("before reading body, Trailer = %#v; want %#v", g, e) 123 } 124 if cookie := res.Cookies()[0]; cookie.Name != "flavor" { 125 t.Errorf("unexpected cookie %q", cookie.Name) 126 } 127 bodyBytes, _ := ioutil.ReadAll(res.Body) 128 if g, e := string(bodyBytes), backendResponse; g != e { 129 t.Errorf("got body %q; expected %q", g, e) 130 } 131 if g, e := res.Trailer.Get("X-Trailer"), "trailer_value"; g != e { 132 t.Errorf("Trailer(X-Trailer) = %q ; want %q", g, e) 133 } 134 if g, e := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != e { 135 t.Errorf("Trailer(X-Unannounced-Trailer) = %q ; want %q", g, e) 136 } 137 138 // Test that a backend failing to be reached or one which doesn't return 139 // a response results in a StatusBadGateway. 140 getReq, _ = http.NewRequest("GET", frontend.URL+"/?mode=hangup", nil) 141 getReq.Close = true 142 res, err = frontendClient.Do(getReq) 143 if err != nil { 144 t.Fatal(err) 145 } 146 res.Body.Close() 147 if res.StatusCode != http.StatusBadGateway { 148 t.Errorf("request to bad proxy = %v; want 502 StatusBadGateway", res.Status) 149 } 150 151 } 152 153 // Issue 16875: remove any proxied headers mentioned in the "Connection" 154 // header value. 155 func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) { 156 const fakeConnectionToken = "X-Fake-Connection-Token" 157 const backendResponse = "I am the backend" 158 159 // someConnHeader is some arbitrary header to be declared as a hop-by-hop header 160 // in the Request's Connection header. 161 const someConnHeader = "X-Some-Conn-Header" 162 163 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 164 if c := r.Header.Get("Connection"); c != "" { 165 t.Errorf("handler got header %q = %q; want empty", "Connection", c) 166 } 167 if c := r.Header.Get(fakeConnectionToken); c != "" { 168 t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c) 169 } 170 if c := r.Header.Get(someConnHeader); c != "" { 171 t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) 172 } 173 w.Header().Add("Connection", "Upgrade, "+fakeConnectionToken) 174 w.Header().Add("Connection", someConnHeader) 175 w.Header().Set(someConnHeader, "should be deleted") 176 w.Header().Set(fakeConnectionToken, "should be deleted") 177 io.WriteString(w, backendResponse) 178 })) 179 defer backend.Close() 180 backendURL, err := url.Parse(backend.URL) 181 if err != nil { 182 t.Fatal(err) 183 } 184 proxyHandler := NewSingleHostReverseProxy(backendURL) 185 frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 186 proxyHandler.ServeHTTP(w, r) 187 if c := r.Header.Get(someConnHeader); c != "should be deleted" { 188 t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted") 189 } 190 if c := r.Header.Get(fakeConnectionToken); c != "should be deleted" { 191 t.Errorf("handler modified header %q = %q; want %q", fakeConnectionToken, c, "should be deleted") 192 } 193 c := r.Header["Connection"] 194 var cf []string 195 for _, f := range c { 196 for _, sf := range strings.Split(f, ",") { 197 if sf = strings.TrimSpace(sf); sf != "" { 198 cf = append(cf, sf) 199 } 200 } 201 } 202 sort.Strings(cf) 203 expectedValues := []string{"Upgrade", someConnHeader, fakeConnectionToken} 204 sort.Strings(expectedValues) 205 if !reflect.DeepEqual(cf, expectedValues) { 206 t.Errorf("handler modified header %q = %q; want %q", "Connection", cf, expectedValues) 207 } 208 })) 209 defer frontend.Close() 210 211 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 212 getReq.Header.Add("Connection", "Upgrade, "+fakeConnectionToken) 213 getReq.Header.Add("Connection", someConnHeader) 214 getReq.Header.Set(someConnHeader, "should be deleted") 215 getReq.Header.Set(fakeConnectionToken, "should be deleted") 216 res, err := frontend.Client().Do(getReq) 217 if err != nil { 218 t.Fatalf("Get: %v", err) 219 } 220 defer res.Body.Close() 221 bodyBytes, err := ioutil.ReadAll(res.Body) 222 if err != nil { 223 t.Fatalf("reading body: %v", err) 224 } 225 if got, want := string(bodyBytes), backendResponse; got != want { 226 t.Errorf("got body %q; want %q", got, want) 227 } 228 if c := res.Header.Get("Connection"); c != "" { 229 t.Errorf("handler got header %q = %q; want empty", "Connection", c) 230 } 231 if c := res.Header.Get(someConnHeader); c != "" { 232 t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) 233 } 234 if c := res.Header.Get(fakeConnectionToken); c != "" { 235 t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c) 236 } 237 } 238 239 func TestXForwardedFor(t *testing.T) { 240 const prevForwardedFor = "client ip" 241 const backendResponse = "I am the backend" 242 const backendStatus = 404 243 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 244 if r.Header.Get("X-Forwarded-For") == "" { 245 t.Errorf("didn't get X-Forwarded-For header") 246 } 247 if !strings.Contains(r.Header.Get("X-Forwarded-For"), prevForwardedFor) { 248 t.Errorf("X-Forwarded-For didn't contain prior data") 249 } 250 w.WriteHeader(backendStatus) 251 w.Write([]byte(backendResponse)) 252 })) 253 defer backend.Close() 254 backendURL, err := url.Parse(backend.URL) 255 if err != nil { 256 t.Fatal(err) 257 } 258 proxyHandler := NewSingleHostReverseProxy(backendURL) 259 frontend := httptest.NewServer(proxyHandler) 260 defer frontend.Close() 261 262 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 263 getReq.Host = "some-name" 264 getReq.Header.Set("Connection", "close") 265 getReq.Header.Set("X-Forwarded-For", prevForwardedFor) 266 getReq.Close = true 267 res, err := frontend.Client().Do(getReq) 268 if err != nil { 269 t.Fatalf("Get: %v", err) 270 } 271 if g, e := res.StatusCode, backendStatus; g != e { 272 t.Errorf("got res.StatusCode %d; expected %d", g, e) 273 } 274 bodyBytes, _ := ioutil.ReadAll(res.Body) 275 if g, e := string(bodyBytes), backendResponse; g != e { 276 t.Errorf("got body %q; expected %q", g, e) 277 } 278 } 279 280 var proxyQueryTests = []struct { 281 baseSuffix string // suffix to add to backend URL 282 reqSuffix string // suffix to add to frontend's request URL 283 want string // what backend should see for final request URL (without ?) 284 }{ 285 {"", "", ""}, 286 {"?sta=tic", "?us=er", "sta=tic&us=er"}, 287 {"", "?us=er", "us=er"}, 288 {"?sta=tic", "", "sta=tic"}, 289 } 290 291 func TestReverseProxyQuery(t *testing.T) { 292 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 293 w.Header().Set("X-Got-Query", r.URL.RawQuery) 294 w.Write([]byte("hi")) 295 })) 296 defer backend.Close() 297 298 for i, tt := range proxyQueryTests { 299 backendURL, err := url.Parse(backend.URL + tt.baseSuffix) 300 if err != nil { 301 t.Fatal(err) 302 } 303 frontend := httptest.NewServer(NewSingleHostReverseProxy(backendURL)) 304 req, _ := http.NewRequest("GET", frontend.URL+tt.reqSuffix, nil) 305 req.Close = true 306 res, err := frontend.Client().Do(req) 307 if err != nil { 308 t.Fatalf("%d. Get: %v", i, err) 309 } 310 if g, e := res.Header.Get("X-Got-Query"), tt.want; g != e { 311 t.Errorf("%d. got query %q; expected %q", i, g, e) 312 } 313 res.Body.Close() 314 frontend.Close() 315 } 316 } 317 318 func TestReverseProxyFlushInterval(t *testing.T) { 319 const expected = "hi" 320 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 321 w.Write([]byte(expected)) 322 })) 323 defer backend.Close() 324 325 backendURL, err := url.Parse(backend.URL) 326 if err != nil { 327 t.Fatal(err) 328 } 329 330 proxyHandler := NewSingleHostReverseProxy(backendURL) 331 proxyHandler.FlushInterval = time.Microsecond 332 333 frontend := httptest.NewServer(proxyHandler) 334 defer frontend.Close() 335 336 req, _ := http.NewRequest("GET", frontend.URL, nil) 337 req.Close = true 338 res, err := frontend.Client().Do(req) 339 if err != nil { 340 t.Fatalf("Get: %v", err) 341 } 342 defer res.Body.Close() 343 if bodyBytes, _ := ioutil.ReadAll(res.Body); string(bodyBytes) != expected { 344 t.Errorf("got body %q; expected %q", bodyBytes, expected) 345 } 346 } 347 348 func TestReverseProxyFlushIntervalHeaders(t *testing.T) { 349 const expected = "hi" 350 stopCh := make(chan struct{}) 351 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 352 w.Header().Add("MyHeader", expected) 353 w.WriteHeader(200) 354 w.(http.Flusher).Flush() 355 <-stopCh 356 })) 357 defer backend.Close() 358 defer close(stopCh) 359 360 backendURL, err := url.Parse(backend.URL) 361 if err != nil { 362 t.Fatal(err) 363 } 364 365 proxyHandler := NewSingleHostReverseProxy(backendURL) 366 proxyHandler.FlushInterval = time.Microsecond 367 368 frontend := httptest.NewServer(proxyHandler) 369 defer frontend.Close() 370 371 req, _ := http.NewRequest("GET", frontend.URL, nil) 372 req.Close = true 373 374 ctx, cancel := context.WithTimeout(req.Context(), 10*time.Second) 375 defer cancel() 376 req = req.WithContext(ctx) 377 378 res, err := frontend.Client().Do(req) 379 if err != nil { 380 t.Fatalf("Get: %v", err) 381 } 382 defer res.Body.Close() 383 384 if res.Header.Get("MyHeader") != expected { 385 t.Errorf("got header %q; expected %q", res.Header.Get("MyHeader"), expected) 386 } 387 } 388 389 func TestReverseProxyCancelation(t *testing.T) { 390 const backendResponse = "I am the backend" 391 392 reqInFlight := make(chan struct{}) 393 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 394 close(reqInFlight) // cause the client to cancel its request 395 396 select { 397 case <-time.After(10 * time.Second): 398 // Note: this should only happen in broken implementations, and the 399 // closenotify case should be instantaneous. 400 t.Error("Handler never saw CloseNotify") 401 return 402 case <-w.(http.CloseNotifier).CloseNotify(): 403 } 404 405 w.WriteHeader(http.StatusOK) 406 w.Write([]byte(backendResponse)) 407 })) 408 409 defer backend.Close() 410 411 backend.Config.ErrorLog = log.New(ioutil.Discard, "", 0) 412 413 backendURL, err := url.Parse(backend.URL) 414 if err != nil { 415 t.Fatal(err) 416 } 417 418 proxyHandler := NewSingleHostReverseProxy(backendURL) 419 420 // Discards errors of the form: 421 // http: proxy error: read tcp 127.0.0.1:44643: use of closed network connection 422 proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) 423 424 frontend := httptest.NewServer(proxyHandler) 425 defer frontend.Close() 426 frontendClient := frontend.Client() 427 428 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 429 go func() { 430 <-reqInFlight 431 frontendClient.Transport.(*http.Transport).CancelRequest(getReq) 432 }() 433 res, err := frontendClient.Do(getReq) 434 if res != nil { 435 t.Errorf("got response %v; want nil", res.Status) 436 } 437 if err == nil { 438 // This should be an error like: 439 // Get http://127.0.0.1:58079: read tcp 127.0.0.1:58079: 440 // use of closed network connection 441 t.Error("Server.Client().Do() returned nil error; want non-nil error") 442 } 443 } 444 445 func req(t *testing.T, v string) *http.Request { 446 req, err := http.ReadRequest(bufio.NewReader(strings.NewReader(v))) 447 if err != nil { 448 t.Fatal(err) 449 } 450 return req 451 } 452 453 // Issue 12344 454 func TestNilBody(t *testing.T) { 455 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 456 w.Write([]byte("hi")) 457 })) 458 defer backend.Close() 459 460 frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 461 backURL, _ := url.Parse(backend.URL) 462 rp := NewSingleHostReverseProxy(backURL) 463 r := req(t, "GET / HTTP/1.0\r\n\r\n") 464 r.Body = nil // this accidentally worked in Go 1.4 and below, so keep it working 465 rp.ServeHTTP(w, r) 466 })) 467 defer frontend.Close() 468 469 res, err := http.Get(frontend.URL) 470 if err != nil { 471 t.Fatal(err) 472 } 473 defer res.Body.Close() 474 slurp, err := ioutil.ReadAll(res.Body) 475 if err != nil { 476 t.Fatal(err) 477 } 478 if string(slurp) != "hi" { 479 t.Errorf("Got %q; want %q", slurp, "hi") 480 } 481 } 482 483 // Issue 15524 484 func TestUserAgentHeader(t *testing.T) { 485 const explicitUA = "explicit UA" 486 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 487 if r.URL.Path == "/noua" { 488 if c := r.Header.Get("User-Agent"); c != "" { 489 t.Errorf("handler got non-empty User-Agent header %q", c) 490 } 491 return 492 } 493 if c := r.Header.Get("User-Agent"); c != explicitUA { 494 t.Errorf("handler got unexpected User-Agent header %q", c) 495 } 496 })) 497 defer backend.Close() 498 backendURL, err := url.Parse(backend.URL) 499 if err != nil { 500 t.Fatal(err) 501 } 502 proxyHandler := NewSingleHostReverseProxy(backendURL) 503 proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests 504 frontend := httptest.NewServer(proxyHandler) 505 defer frontend.Close() 506 frontendClient := frontend.Client() 507 508 getReq, _ := http.NewRequest("GET", frontend.URL, nil) 509 getReq.Header.Set("User-Agent", explicitUA) 510 getReq.Close = true 511 res, err := frontendClient.Do(getReq) 512 if err != nil { 513 t.Fatalf("Get: %v", err) 514 } 515 res.Body.Close() 516 517 getReq, _ = http.NewRequest("GET", frontend.URL+"/noua", nil) 518 getReq.Header.Set("User-Agent", "") 519 getReq.Close = true 520 res, err = frontendClient.Do(getReq) 521 if err != nil { 522 t.Fatalf("Get: %v", err) 523 } 524 res.Body.Close() 525 } 526 527 type bufferPool struct { 528 get func() []byte 529 put func([]byte) 530 } 531 532 func (bp bufferPool) Get() []byte { return bp.get() } 533 func (bp bufferPool) Put(v []byte) { bp.put(v) } 534 535 func TestReverseProxyGetPutBuffer(t *testing.T) { 536 const msg = "hi" 537 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 538 io.WriteString(w, msg) 539 })) 540 defer backend.Close() 541 542 backendURL, err := url.Parse(backend.URL) 543 if err != nil { 544 t.Fatal(err) 545 } 546 547 var ( 548 mu sync.Mutex 549 log []string 550 ) 551 addLog := func(event string) { 552 mu.Lock() 553 defer mu.Unlock() 554 log = append(log, event) 555 } 556 rp := NewSingleHostReverseProxy(backendURL) 557 const size = 1234 558 rp.BufferPool = bufferPool{ 559 get: func() []byte { 560 addLog("getBuf") 561 return make([]byte, size) 562 }, 563 put: func(p []byte) { 564 addLog("putBuf-" + strconv.Itoa(len(p))) 565 }, 566 } 567 frontend := httptest.NewServer(rp) 568 defer frontend.Close() 569 570 req, _ := http.NewRequest("GET", frontend.URL, nil) 571 req.Close = true 572 res, err := frontend.Client().Do(req) 573 if err != nil { 574 t.Fatalf("Get: %v", err) 575 } 576 slurp, err := ioutil.ReadAll(res.Body) 577 res.Body.Close() 578 if err != nil { 579 t.Fatalf("reading body: %v", err) 580 } 581 if string(slurp) != msg { 582 t.Errorf("msg = %q; want %q", slurp, msg) 583 } 584 wantLog := []string{"getBuf", "putBuf-" + strconv.Itoa(size)} 585 mu.Lock() 586 defer mu.Unlock() 587 if !reflect.DeepEqual(log, wantLog) { 588 t.Errorf("Log events = %q; want %q", log, wantLog) 589 } 590 } 591 592 func TestReverseProxy_Post(t *testing.T) { 593 const backendResponse = "I am the backend" 594 const backendStatus = 200 595 var requestBody = bytes.Repeat([]byte("a"), 1<<20) 596 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 597 slurp, err := ioutil.ReadAll(r.Body) 598 if err != nil { 599 t.Errorf("Backend body read = %v", err) 600 } 601 if len(slurp) != len(requestBody) { 602 t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody)) 603 } 604 if !bytes.Equal(slurp, requestBody) { 605 t.Error("Backend read wrong request body.") // 1MB; omitting details 606 } 607 w.Write([]byte(backendResponse)) 608 })) 609 defer backend.Close() 610 backendURL, err := url.Parse(backend.URL) 611 if err != nil { 612 t.Fatal(err) 613 } 614 proxyHandler := NewSingleHostReverseProxy(backendURL) 615 frontend := httptest.NewServer(proxyHandler) 616 defer frontend.Close() 617 618 postReq, _ := http.NewRequest("POST", frontend.URL, bytes.NewReader(requestBody)) 619 res, err := frontend.Client().Do(postReq) 620 if err != nil { 621 t.Fatalf("Do: %v", err) 622 } 623 if g, e := res.StatusCode, backendStatus; g != e { 624 t.Errorf("got res.StatusCode %d; expected %d", g, e) 625 } 626 bodyBytes, _ := ioutil.ReadAll(res.Body) 627 if g, e := string(bodyBytes), backendResponse; g != e { 628 t.Errorf("got body %q; expected %q", g, e) 629 } 630 } 631 632 type RoundTripperFunc func(*http.Request) (*http.Response, error) 633 634 func (fn RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { 635 return fn(req) 636 } 637 638 // Issue 16036: send a Request with a nil Body when possible 639 func TestReverseProxy_NilBody(t *testing.T) { 640 backendURL, _ := url.Parse("http://fake.tld/") 641 proxyHandler := NewSingleHostReverseProxy(backendURL) 642 proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests 643 proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) { 644 if req.Body != nil { 645 t.Error("Body != nil; want a nil Body") 646 } 647 return nil, errors.New("done testing the interesting part; so force a 502 Gateway error") 648 }) 649 frontend := httptest.NewServer(proxyHandler) 650 defer frontend.Close() 651 652 res, err := frontend.Client().Get(frontend.URL) 653 if err != nil { 654 t.Fatal(err) 655 } 656 defer res.Body.Close() 657 if res.StatusCode != 502 { 658 t.Errorf("status code = %v; want 502 (Gateway Error)", res.Status) 659 } 660 } 661 662 // Issue 33142: always allocate the request headers 663 func TestReverseProxy_AllocatedHeader(t *testing.T) { 664 proxyHandler := new(ReverseProxy) 665 proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests 666 proxyHandler.Director = func(*http.Request) {} // noop 667 proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) { 668 if req.Header == nil { 669 t.Error("Header == nil; want a non-nil Header") 670 } 671 return nil, errors.New("done testing the interesting part; so force a 502 Gateway error") 672 }) 673 674 proxyHandler.ServeHTTP(httptest.NewRecorder(), &http.Request{ 675 Method: "GET", 676 URL: &url.URL{Scheme: "http", Host: "fake.tld", Path: "/"}, 677 Proto: "HTTP/1.0", 678 ProtoMajor: 1, 679 }) 680 } 681 682 // Issue 14237. Test ModifyResponse and that an error from it 683 // causes the proxy to return StatusBadGateway, or StatusOK otherwise. 684 func TestReverseProxyModifyResponse(t *testing.T) { 685 backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 686 w.Header().Add("X-Hit-Mod", fmt.Sprintf("%v", r.URL.Path == "/mod")) 687 })) 688 defer backendServer.Close() 689 690 rpURL, _ := url.Parse(backendServer.URL) 691 rproxy := NewSingleHostReverseProxy(rpURL) 692 rproxy.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests 693 rproxy.ModifyResponse = func(resp *http.Response) error { 694 if resp.Header.Get("X-Hit-Mod") != "true" { 695 return fmt.Errorf("tried to by-pass proxy") 696 } 697 return nil 698 } 699 700 frontendProxy := httptest.NewServer(rproxy) 701 defer frontendProxy.Close() 702 703 tests := []struct { 704 url string 705 wantCode int 706 }{ 707 {frontendProxy.URL + "/mod", http.StatusOK}, 708 {frontendProxy.URL + "/schedule", http.StatusBadGateway}, 709 } 710 711 for i, tt := range tests { 712 resp, err := http.Get(tt.url) 713 if err != nil { 714 t.Fatalf("failed to reach proxy: %v", err) 715 } 716 if g, e := resp.StatusCode, tt.wantCode; g != e { 717 t.Errorf("#%d: got res.StatusCode %d; expected %d", i, g, e) 718 } 719 resp.Body.Close() 720 } 721 } 722 723 type failingRoundTripper struct{} 724 725 func (failingRoundTripper) RoundTrip(*http.Request) (*http.Response, error) { 726 return nil, errors.New("some error") 727 } 728 729 type staticResponseRoundTripper struct{ res *http.Response } 730 731 func (rt staticResponseRoundTripper) RoundTrip(*http.Request) (*http.Response, error) { 732 return rt.res, nil 733 } 734 735 func TestReverseProxyErrorHandler(t *testing.T) { 736 tests := []struct { 737 name string 738 wantCode int 739 errorHandler func(http.ResponseWriter, *http.Request, error) 740 transport http.RoundTripper // defaults to failingRoundTripper 741 modifyResponse func(*http.Response) error 742 }{ 743 { 744 name: "default", 745 wantCode: http.StatusBadGateway, 746 }, 747 { 748 name: "errorhandler", 749 wantCode: http.StatusTeapot, 750 errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) }, 751 }, 752 { 753 name: "modifyresponse_noerr", 754 transport: staticResponseRoundTripper{ 755 &http.Response{StatusCode: 345, Body: http.NoBody}, 756 }, 757 modifyResponse: func(res *http.Response) error { 758 res.StatusCode++ 759 return nil 760 }, 761 errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) }, 762 wantCode: 346, 763 }, 764 { 765 name: "modifyresponse_err", 766 transport: staticResponseRoundTripper{ 767 &http.Response{StatusCode: 345, Body: http.NoBody}, 768 }, 769 modifyResponse: func(res *http.Response) error { 770 res.StatusCode++ 771 return errors.New("some error to trigger errorHandler") 772 }, 773 errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) }, 774 wantCode: http.StatusTeapot, 775 }, 776 } 777 778 for _, tt := range tests { 779 t.Run(tt.name, func(t *testing.T) { 780 target := &url.URL{ 781 Scheme: "http", 782 Host: "dummy.tld", 783 Path: "/", 784 } 785 rproxy := NewSingleHostReverseProxy(target) 786 rproxy.Transport = tt.transport 787 rproxy.ModifyResponse = tt.modifyResponse 788 if rproxy.Transport == nil { 789 rproxy.Transport = failingRoundTripper{} 790 } 791 rproxy.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests 792 if tt.errorHandler != nil { 793 rproxy.ErrorHandler = tt.errorHandler 794 } 795 frontendProxy := httptest.NewServer(rproxy) 796 defer frontendProxy.Close() 797 798 resp, err := http.Get(frontendProxy.URL + "/test") 799 if err != nil { 800 t.Fatalf("failed to reach proxy: %v", err) 801 } 802 if g, e := resp.StatusCode, tt.wantCode; g != e { 803 t.Errorf("got res.StatusCode %d; expected %d", g, e) 804 } 805 resp.Body.Close() 806 }) 807 } 808 } 809 810 // Issue 16659: log errors from short read 811 func TestReverseProxy_CopyBuffer(t *testing.T) { 812 backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 813 out := "this call was relayed by the reverse proxy" 814 // Coerce a wrong content length to induce io.UnexpectedEOF 815 w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2)) 816 fmt.Fprintln(w, out) 817 })) 818 defer backendServer.Close() 819 820 rpURL, err := url.Parse(backendServer.URL) 821 if err != nil { 822 t.Fatal(err) 823 } 824 825 var proxyLog bytes.Buffer 826 rproxy := NewSingleHostReverseProxy(rpURL) 827 rproxy.ErrorLog = log.New(&proxyLog, "", log.Lshortfile) 828 donec := make(chan bool, 1) 829 frontendProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 830 defer func() { donec <- true }() 831 rproxy.ServeHTTP(w, r) 832 })) 833 defer frontendProxy.Close() 834 835 if _, err = frontendProxy.Client().Get(frontendProxy.URL); err == nil { 836 t.Fatalf("want non-nil error") 837 } 838 // The race detector complains about the proxyLog usage in logf in copyBuffer 839 // and our usage below with proxyLog.Bytes() so we're explicitly using a 840 // channel to ensure that the ReverseProxy's ServeHTTP is done before we 841 // continue after Get. 842 <-donec 843 844 expected := []string{ 845 "EOF", 846 "read", 847 } 848 for _, phrase := range expected { 849 if !bytes.Contains(proxyLog.Bytes(), []byte(phrase)) { 850 t.Errorf("expected log to contain phrase %q", phrase) 851 } 852 } 853 } 854 855 type staticTransport struct { 856 res *http.Response 857 } 858 859 func (t *staticTransport) RoundTrip(r *http.Request) (*http.Response, error) { 860 return t.res, nil 861 } 862 863 func BenchmarkServeHTTP(b *testing.B) { 864 res := &http.Response{ 865 StatusCode: 200, 866 Body: ioutil.NopCloser(strings.NewReader("")), 867 } 868 proxy := &ReverseProxy{ 869 Director: func(*http.Request) {}, 870 Transport: &staticTransport{res}, 871 } 872 873 w := httptest.NewRecorder() 874 r := httptest.NewRequest("GET", "/", nil) 875 876 b.ReportAllocs() 877 for i := 0; i < b.N; i++ { 878 proxy.ServeHTTP(w, r) 879 } 880 } 881 882 func TestServeHTTPDeepCopy(t *testing.T) { 883 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 884 w.Write([]byte("Hello Gopher!")) 885 })) 886 defer backend.Close() 887 backendURL, err := url.Parse(backend.URL) 888 if err != nil { 889 t.Fatal(err) 890 } 891 892 type result struct { 893 before, after string 894 } 895 896 resultChan := make(chan result, 1) 897 proxyHandler := NewSingleHostReverseProxy(backendURL) 898 frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 899 before := r.URL.String() 900 proxyHandler.ServeHTTP(w, r) 901 after := r.URL.String() 902 resultChan <- result{before: before, after: after} 903 })) 904 defer frontend.Close() 905 906 want := result{before: "/", after: "/"} 907 908 res, err := frontend.Client().Get(frontend.URL) 909 if err != nil { 910 t.Fatalf("Do: %v", err) 911 } 912 res.Body.Close() 913 914 got := <-resultChan 915 if got != want { 916 t.Errorf("got = %+v; want = %+v", got, want) 917 } 918 } 919 920 // Issue 18327: verify we always do a deep copy of the Request.Header map 921 // before any mutations. 922 func TestClonesRequestHeaders(t *testing.T) { 923 log.SetOutput(ioutil.Discard) 924 defer log.SetOutput(os.Stderr) 925 req, _ := http.NewRequest("GET", "http://foo.tld/", nil) 926 req.RemoteAddr = "1.2.3.4:56789" 927 rp := &ReverseProxy{ 928 Director: func(req *http.Request) { 929 req.Header.Set("From-Director", "1") 930 }, 931 Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { 932 if v := req.Header.Get("From-Director"); v != "1" { 933 t.Errorf("From-Directory value = %q; want 1", v) 934 } 935 return nil, io.EOF 936 }), 937 } 938 rp.ServeHTTP(httptest.NewRecorder(), req) 939 940 if req.Header.Get("From-Director") == "1" { 941 t.Error("Director header mutation modified caller's request") 942 } 943 if req.Header.Get("X-Forwarded-For") != "" { 944 t.Error("X-Forward-For header mutation modified caller's request") 945 } 946 947 } 948 949 type roundTripperFunc func(req *http.Request) (*http.Response, error) 950 951 func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { 952 return fn(req) 953 } 954 955 func TestModifyResponseClosesBody(t *testing.T) { 956 req, _ := http.NewRequest("GET", "http://foo.tld/", nil) 957 req.RemoteAddr = "1.2.3.4:56789" 958 closeCheck := new(checkCloser) 959 logBuf := new(bytes.Buffer) 960 outErr := errors.New("ModifyResponse error") 961 rp := &ReverseProxy{ 962 Director: func(req *http.Request) {}, 963 Transport: &staticTransport{&http.Response{ 964 StatusCode: 200, 965 Body: closeCheck, 966 }}, 967 ErrorLog: log.New(logBuf, "", 0), 968 ModifyResponse: func(*http.Response) error { 969 return outErr 970 }, 971 } 972 rec := httptest.NewRecorder() 973 rp.ServeHTTP(rec, req) 974 res := rec.Result() 975 if g, e := res.StatusCode, http.StatusBadGateway; g != e { 976 t.Errorf("got res.StatusCode %d; expected %d", g, e) 977 } 978 if !closeCheck.closed { 979 t.Errorf("body should have been closed") 980 } 981 if g, e := logBuf.String(), outErr.Error(); !strings.Contains(g, e) { 982 t.Errorf("ErrorLog %q does not contain %q", g, e) 983 } 984 } 985 986 type checkCloser struct { 987 closed bool 988 } 989 990 func (cc *checkCloser) Close() error { 991 cc.closed = true 992 return nil 993 } 994 995 func (cc *checkCloser) Read(b []byte) (int, error) { 996 return len(b), nil 997 } 998 999 // Issue 23643: panic on body copy error 1000 func TestReverseProxy_PanicBodyError(t *testing.T) { 1001 log.SetOutput(ioutil.Discard) 1002 defer log.SetOutput(os.Stderr) 1003 backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1004 out := "this call was relayed by the reverse proxy" 1005 // Coerce a wrong content length to induce io.ErrUnexpectedEOF 1006 w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2)) 1007 fmt.Fprintln(w, out) 1008 })) 1009 defer backendServer.Close() 1010 1011 rpURL, err := url.Parse(backendServer.URL) 1012 if err != nil { 1013 t.Fatal(err) 1014 } 1015 1016 rproxy := NewSingleHostReverseProxy(rpURL) 1017 1018 // Ensure that the handler panics when the body read encounters an 1019 // io.ErrUnexpectedEOF 1020 defer func() { 1021 err := recover() 1022 if err == nil { 1023 t.Fatal("handler should have panicked") 1024 } 1025 if err != http.ErrAbortHandler { 1026 t.Fatal("expected ErrAbortHandler, got", err) 1027 } 1028 }() 1029 req, _ := http.NewRequest("GET", "http://foo.tld/", nil) 1030 rproxy.ServeHTTP(httptest.NewRecorder(), req) 1031 } 1032 1033 func TestSelectFlushInterval(t *testing.T) { 1034 tests := []struct { 1035 name string 1036 p *ReverseProxy 1037 req *http.Request 1038 res *http.Response 1039 want time.Duration 1040 }{ 1041 { 1042 name: "default", 1043 res: &http.Response{}, 1044 p: &ReverseProxy{FlushInterval: 123}, 1045 want: 123, 1046 }, 1047 { 1048 name: "server-sent events overrides non-zero", 1049 res: &http.Response{ 1050 Header: http.Header{ 1051 "Content-Type": {"text/event-stream"}, 1052 }, 1053 }, 1054 p: &ReverseProxy{FlushInterval: 123}, 1055 want: -1, 1056 }, 1057 { 1058 name: "server-sent events overrides zero", 1059 res: &http.Response{ 1060 Header: http.Header{ 1061 "Content-Type": {"text/event-stream"}, 1062 }, 1063 }, 1064 p: &ReverseProxy{FlushInterval: 0}, 1065 want: -1, 1066 }, 1067 } 1068 for _, tt := range tests { 1069 t.Run(tt.name, func(t *testing.T) { 1070 got := tt.p.flushInterval(tt.req, tt.res) 1071 if got != tt.want { 1072 t.Errorf("flushLatency = %v; want %v", got, tt.want) 1073 } 1074 }) 1075 } 1076 } 1077 1078 func TestReverseProxyWebSocket(t *testing.T) { 1079 backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1080 if upgradeType(r.Header) != "websocket" { 1081 t.Error("unexpected backend request") 1082 http.Error(w, "unexpected request", 400) 1083 return 1084 } 1085 c, _, err := w.(http.Hijacker).Hijack() 1086 if err != nil { 1087 t.Error(err) 1088 return 1089 } 1090 defer c.Close() 1091 io.WriteString(c, "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n") 1092 bs := bufio.NewScanner(c) 1093 if !bs.Scan() { 1094 t.Errorf("backend failed to read line from client: %v", bs.Err()) 1095 return 1096 } 1097 fmt.Fprintf(c, "backend got %q\n", bs.Text()) 1098 })) 1099 defer backendServer.Close() 1100 1101 backURL, _ := url.Parse(backendServer.URL) 1102 rproxy := NewSingleHostReverseProxy(backURL) 1103 rproxy.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests 1104 rproxy.ModifyResponse = func(res *http.Response) error { 1105 res.Header.Add("X-Modified", "true") 1106 return nil 1107 } 1108 1109 handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { 1110 rw.Header().Set("X-Header", "X-Value") 1111 rproxy.ServeHTTP(rw, req) 1112 }) 1113 1114 frontendProxy := httptest.NewServer(handler) 1115 defer frontendProxy.Close() 1116 1117 req, _ := http.NewRequest("GET", frontendProxy.URL, nil) 1118 req.Header.Set("Connection", "Upgrade") 1119 req.Header.Set("Upgrade", "websocket") 1120 1121 c := frontendProxy.Client() 1122 res, err := c.Do(req) 1123 if err != nil { 1124 t.Fatal(err) 1125 } 1126 if res.StatusCode != 101 { 1127 t.Fatalf("status = %v; want 101", res.Status) 1128 } 1129 1130 got := res.Header.Get("X-Header") 1131 want := "X-Value" 1132 if got != want { 1133 t.Errorf("Header(XHeader) = %q; want %q", got, want) 1134 } 1135 1136 if upgradeType(res.Header) != "websocket" { 1137 t.Fatalf("not websocket upgrade; got %#v", res.Header) 1138 } 1139 rwc, ok := res.Body.(io.ReadWriteCloser) 1140 if !ok { 1141 t.Fatalf("response body is of type %T; does not implement ReadWriteCloser", res.Body) 1142 } 1143 defer rwc.Close() 1144 1145 if got, want := res.Header.Get("X-Modified"), "true"; got != want { 1146 t.Errorf("response X-Modified header = %q; want %q", got, want) 1147 } 1148 1149 io.WriteString(rwc, "Hello\n") 1150 bs := bufio.NewScanner(rwc) 1151 if !bs.Scan() { 1152 t.Fatalf("Scan: %v", bs.Err()) 1153 } 1154 got = bs.Text() 1155 want = `backend got "Hello"` 1156 if got != want { 1157 t.Errorf("got %#q, want %#q", got, want) 1158 } 1159 } 1160 1161 func TestUnannouncedTrailer(t *testing.T) { 1162 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1163 w.WriteHeader(http.StatusOK) 1164 w.(http.Flusher).Flush() 1165 w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value") 1166 })) 1167 defer backend.Close() 1168 backendURL, err := url.Parse(backend.URL) 1169 if err != nil { 1170 t.Fatal(err) 1171 } 1172 proxyHandler := NewSingleHostReverseProxy(backendURL) 1173 proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests 1174 frontend := httptest.NewServer(proxyHandler) 1175 defer frontend.Close() 1176 frontendClient := frontend.Client() 1177 1178 res, err := frontendClient.Get(frontend.URL) 1179 if err != nil { 1180 t.Fatalf("Get: %v", err) 1181 } 1182 1183 ioutil.ReadAll(res.Body) 1184 1185 if g, w := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != w { 1186 t.Errorf("Trailer(X-Unannounced-Trailer) = %q; want %q", g, w) 1187 } 1188 1189 } 1190 1191 func TestSingleJoinSlash(t *testing.T) { 1192 tests := []struct { 1193 slasha string 1194 slashb string 1195 expected string 1196 }{ 1197 {"https://www.google.com/", "/favicon.ico", "https://www.google.com/favicon.ico"}, 1198 {"https://www.google.com", "/favicon.ico", "https://www.google.com/favicon.ico"}, 1199 {"https://www.google.com", "favicon.ico", "https://www.google.com/favicon.ico"}, 1200 {"https://www.google.com", "", "https://www.google.com/"}, 1201 {"", "favicon.ico", "/favicon.ico"}, 1202 } 1203 for _, tt := range tests { 1204 if got := singleJoiningSlash(tt.slasha, tt.slashb); got != tt.expected { 1205 t.Errorf("singleJoiningSlash(%s,%s) want %s got %s", 1206 tt.slasha, 1207 tt.slashb, 1208 tt.expected, 1209 got) 1210 } 1211 } 1212 } 1213
View as plain text