1
2
3
4
5
6
7 package httputil
8
import (
	"context"
	"fmt"
	"io"
	"log"
	"mime"
	"net"
	"net/http"
	"net/textproto"
	"net/url"
	"strings"
	"sync"
	"time"

	"golang.org/x/net/http/httpguts"
)
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
// ReverseProxy is an HTTP handler that sends each incoming request to a
// backend (chosen by Director) and copies the backend's response back to
// the client.
type ReverseProxy struct {
	// Director rewrites the cloned outgoing request (URL, headers, ...)
	// before it is sent to the backend. ServeHTTP calls it unconditionally,
	// so it must be non-nil. Hop-by-hop headers are removed from the
	// outgoing request after Director returns.
	Director func(*http.Request)

	// Transport performs the round trip to the backend. If nil,
	// http.DefaultTransport is used.
	Transport http.RoundTripper

	// FlushInterval controls flushing while the response body is copied to
	// the client (when the ResponseWriter supports http.Flusher). Zero
	// disables periodic flushing; a negative value flushes after every
	// write; a positive value bounds how long written data may stay
	// unflushed. Server-Sent Events ("text/event-stream") responses and
	// responses with ContentLength == -1 are always flushed after every
	// write, regardless of this field.
	FlushInterval time.Duration

	// ErrorLog receives the proxy's log output. If nil, the log package's
	// standard logger is used.
	ErrorLog *log.Logger

	// BufferPool, if non-nil, supplies the byte slices used to copy
	// response bodies. When it is nil, or Get returns an empty slice, a
	// fresh 32 KiB buffer is allocated for the copy.
	BufferPool BufferPool

	// ModifyResponse, if non-nil, is called with the backend response
	// before anything is written to the client (for non-101 responses,
	// after hop-by-hop headers have been removed). If it returns an error,
	// the response body is closed and ErrorHandler is invoked with that
	// error.
	ModifyResponse func(*http.Response) error

	// ErrorHandler is called for errors reaching the backend, errors
	// returned by ModifyResponse, and failed protocol switches. If nil,
	// the default handler logs the error and replies 502 Bad Gateway.
	ErrorHandler func(http.ResponseWriter, *http.Request, error)
}
94
95
96
// BufferPool supplies and reclaims the temporary byte slices a ReverseProxy
// uses when copying response bodies. Get is called once per body copy and
// the same slice is passed back to Put when that copy finishes. If Get
// returns an empty slice, the proxy allocates its own 32 KiB buffer for
// that copy.
type BufferPool interface {
	Get() []byte
	Put([]byte)
}
101
// singleJoiningSlash concatenates a and b with one slash at the seam.
// A trailing "/" on a and a leading "/" on b are merged into one, and a
// "/" is inserted when neither side has one.
func singleJoiningSlash(a, b string) string {
	endsWithSlash := strings.HasSuffix(a, "/")
	startsWithSlash := strings.HasPrefix(b, "/")
	if endsWithSlash && startsWithSlash {
		return a + strings.TrimPrefix(b, "/")
	}
	if !endsWithSlash && !startsWithSlash {
		return a + "/" + b
	}
	return a + b
}
113
// joinURLPath joins the path of a with the path of b, using one slash at
// the seam, and returns both the decoded result (path) and the escaped
// result (rawpath). When neither URL has a RawPath, only the decoded paths
// are joined and rawpath is "". Otherwise the slash decision is made on the
// escaped forms and applied identically to the decoded forms.
func joinURLPath(a, b *url.URL) (path, rawpath string) {
	// join glues x and y given whether x ends and y begins with a slash.
	join := func(x, y string, endSlash, startSlash bool) string {
		switch {
		case endSlash && startSlash:
			return x + y[1:]
		case !endSlash && !startSlash:
			return x + "/" + y
		}
		return x + y
	}

	if a.RawPath == "" && b.RawPath == "" {
		joined := join(a.Path, b.Path, strings.HasSuffix(a.Path, "/"), strings.HasPrefix(b.Path, "/"))
		return joined, ""
	}

	// Same seam logic, decided on the escaped paths so the raw result is
	// well formed.
	escA, escB := a.EscapedPath(), b.EscapedPath()
	endSlash := strings.HasSuffix(escA, "/")
	startSlash := strings.HasPrefix(escB, "/")
	return join(a.Path, b.Path, endSlash, startSlash), join(escA, escB, endSlash, startSlash)
}
134
135
136
137
138
139
140
141
142 func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy {
143 targetQuery := target.RawQuery
144 director := func(req *http.Request) {
145 req.URL.Scheme = target.Scheme
146 req.URL.Host = target.Host
147 req.URL.Path, req.URL.RawPath = joinURLPath(target, req.URL)
148 if targetQuery == "" || req.URL.RawQuery == "" {
149 req.URL.RawQuery = targetQuery + req.URL.RawQuery
150 } else {
151 req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
152 }
153 if _, ok := req.Header["User-Agent"]; !ok {
154
155 req.Header.Set("User-Agent", "")
156 }
157 }
158 return &ReverseProxy{Director: director}
159 }
160
// copyHeader appends every value of every header in src to dst. Values
// already present in dst are kept, and keys are canonicalized by
// http.Header.Add.
func copyHeader(dst, src http.Header) {
	for key, values := range src {
		for i := range values {
			dst.Add(key, values[i])
		}
	}
}
168
169
170
171
172
173
// hopHeaders lists hop-by-hop headers (in canonical form), which apply only
// to a single connection and must not be forwarded. ServeHTTP deletes them
// from the outgoing request, except that "Te: trailers" is kept and a
// requested protocol upgrade is re-added. It also deletes them from the
// backend response for non-101 responses.
var hopHeaders = []string{
	"Connection",
	"Proxy-Connection",
	"Keep-Alive",
	"Proxy-Authenticate",
	"Proxy-Authorization",
	"Te",
	"Trailer",
	"Transfer-Encoding",
	"Upgrade",
}
185
186 func (p *ReverseProxy) defaultErrorHandler(rw http.ResponseWriter, req *http.Request, err error) {
187 p.logf("http: proxy error: %v", err)
188 rw.WriteHeader(http.StatusBadGateway)
189 }
190
191 func (p *ReverseProxy) getErrorHandler() func(http.ResponseWriter, *http.Request, error) {
192 if p.ErrorHandler != nil {
193 return p.ErrorHandler
194 }
195 return p.defaultErrorHandler
196 }
197
198
199
200 func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, req *http.Request) bool {
201 if p.ModifyResponse == nil {
202 return true
203 }
204 if err := p.ModifyResponse(res); err != nil {
205 res.Body.Close()
206 p.getErrorHandler()(rw, req, err)
207 return false
208 }
209 return true
210 }
211
212 func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
213 transport := p.Transport
214 if transport == nil {
215 transport = http.DefaultTransport
216 }
217
218 ctx := req.Context()
219 if cn, ok := rw.(http.CloseNotifier); ok {
220 var cancel context.CancelFunc
221 ctx, cancel = context.WithCancel(ctx)
222 defer cancel()
223 notifyChan := cn.CloseNotify()
224 go func() {
225 select {
226 case <-notifyChan:
227 cancel()
228 case <-ctx.Done():
229 }
230 }()
231 }
232
233 outreq := req.Clone(ctx)
234 if req.ContentLength == 0 {
235 outreq.Body = nil
236 }
237 if outreq.Header == nil {
238 outreq.Header = make(http.Header)
239 }
240
241 p.Director(outreq)
242 outreq.Close = false
243
244 reqUpType := upgradeType(outreq.Header)
245 removeConnectionHeaders(outreq.Header)
246
247
248
249
250 for _, h := range hopHeaders {
251 hv := outreq.Header.Get(h)
252 if hv == "" {
253 continue
254 }
255 if h == "Te" && hv == "trailers" {
256
257
258
259
260
261
262 continue
263 }
264 outreq.Header.Del(h)
265 }
266
267
268
269 if reqUpType != "" {
270 outreq.Header.Set("Connection", "Upgrade")
271 outreq.Header.Set("Upgrade", reqUpType)
272 }
273
274 if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
275
276
277
278 prior, ok := outreq.Header["X-Forwarded-For"]
279 omit := ok && prior == nil
280 if len(prior) > 0 {
281 clientIP = strings.Join(prior, ", ") + ", " + clientIP
282 }
283 if !omit {
284 outreq.Header.Set("X-Forwarded-For", clientIP)
285 }
286 }
287
288 res, err := transport.RoundTrip(outreq)
289 if err != nil {
290 p.getErrorHandler()(rw, outreq, err)
291 return
292 }
293
294
295 if res.StatusCode == http.StatusSwitchingProtocols {
296 if !p.modifyResponse(rw, res, outreq) {
297 return
298 }
299 p.handleUpgradeResponse(rw, outreq, res)
300 return
301 }
302
303 removeConnectionHeaders(res.Header)
304
305 for _, h := range hopHeaders {
306 res.Header.Del(h)
307 }
308
309 if !p.modifyResponse(rw, res, outreq) {
310 return
311 }
312
313 copyHeader(rw.Header(), res.Header)
314
315
316
317 announcedTrailers := len(res.Trailer)
318 if announcedTrailers > 0 {
319 trailerKeys := make([]string, 0, len(res.Trailer))
320 for k := range res.Trailer {
321 trailerKeys = append(trailerKeys, k)
322 }
323 rw.Header().Add("Trailer", strings.Join(trailerKeys, ", "))
324 }
325
326 rw.WriteHeader(res.StatusCode)
327
328 err = p.copyResponse(rw, res.Body, p.flushInterval(res))
329 if err != nil {
330 defer res.Body.Close()
331
332
333
334 if !shouldPanicOnCopyError(req) {
335 p.logf("suppressing panic for copyResponse error in test; copy error: %v", err)
336 return
337 }
338 panic(http.ErrAbortHandler)
339 }
340 res.Body.Close()
341
342 if len(res.Trailer) > 0 {
343
344
345
346 if fl, ok := rw.(http.Flusher); ok {
347 fl.Flush()
348 }
349 }
350
351 if len(res.Trailer) == announcedTrailers {
352 copyHeader(rw.Header(), res.Trailer)
353 return
354 }
355
356 for k, vv := range res.Trailer {
357 k = http.TrailerPrefix + k
358 for _, v := range vv {
359 rw.Header().Add(k, v)
360 }
361 }
362 }
363
// inOurTests forces shouldPanicOnCopyError to report true. It is presumably
// set by this package's own tests; nothing visible here assigns it.
var inOurTests bool

// shouldPanicOnCopyError reports whether ServeHTTP should abort with
// panic(http.ErrAbortHandler) after failing to copy the response body.
// That is the case under inOurTests, or when req's context carries
// http.ServerContextKey, meaning an http.Server is serving it.
// Otherwise (e.g. ServeHTTP called directly with a recorder), presumably
// nothing would recover the panic, so the caller logs instead.
func shouldPanicOnCopyError(req *http.Request) bool {
	servedByServer := req.Context().Value(http.ServerContextKey) != nil
	return inOurTests || servedByServer
}
385
386
387
// removeConnectionHeaders deletes every header whose name is listed in h's
// Connection header(s), which may hold comma-separated names across
// several lines. The Connection header itself is not removed here
// (unless it names itself).
func removeConnectionHeaders(h http.Header) {
	for _, line := range h["Connection"] {
		for _, name := range strings.Split(line, ",") {
			name = textproto.TrimString(name)
			if name == "" {
				continue
			}
			h.Del(name)
		}
	}
}
397
398
399
400 func (p *ReverseProxy) flushInterval(res *http.Response) time.Duration {
401 resCT := res.Header.Get("Content-Type")
402
403
404
405 if resCT == "text/event-stream" {
406 return -1
407 }
408
409
410 if res.ContentLength == -1 {
411 return -1
412 }
413
414 return p.FlushInterval
415 }
416
417 func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader, flushInterval time.Duration) error {
418 if flushInterval != 0 {
419 if wf, ok := dst.(writeFlusher); ok {
420 mlw := &maxLatencyWriter{
421 dst: wf,
422 latency: flushInterval,
423 }
424 defer mlw.stop()
425
426
427 mlw.flushPending = true
428 mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush)
429
430 dst = mlw
431 }
432 }
433
434 var buf []byte
435 if p.BufferPool != nil {
436 buf = p.BufferPool.Get()
437 defer p.BufferPool.Put(buf)
438 }
439 _, err := p.copyBuffer(dst, src, buf)
440 return err
441 }
442
443
444
445 func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) {
446 if len(buf) == 0 {
447 buf = make([]byte, 32*1024)
448 }
449 var written int64
450 for {
451 nr, rerr := src.Read(buf)
452 if rerr != nil && rerr != io.EOF && rerr != context.Canceled {
453 p.logf("httputil: ReverseProxy read error during body copy: %v", rerr)
454 }
455 if nr > 0 {
456 nw, werr := dst.Write(buf[:nr])
457 if nw > 0 {
458 written += int64(nw)
459 }
460 if werr != nil {
461 return written, werr
462 }
463 if nr != nw {
464 return written, io.ErrShortWrite
465 }
466 }
467 if rerr != nil {
468 if rerr == io.EOF {
469 rerr = nil
470 }
471 return written, rerr
472 }
473 }
474 }
475
476 func (p *ReverseProxy) logf(format string, args ...interface{}) {
477 if p.ErrorLog != nil {
478 p.ErrorLog.Printf(format, args...)
479 } else {
480 log.Printf(format, args...)
481 }
482 }
483
// writeFlusher is a writer whose buffered output can be pushed out
// explicitly with Flush.
type writeFlusher interface {
	io.Writer
	http.Flusher
}

// maxLatencyWriter wraps a writeFlusher and flushes written data at most
// latency after the first unflushed write (until stop is called). A
// negative latency flushes after every write.
type maxLatencyWriter struct {
	dst     writeFlusher
	latency time.Duration // negative means flush after every Write

	mu           sync.Mutex  // guards t and flushPending, serializes dst access
	t            *time.Timer // delayed-flush timer, created lazily
	flushPending bool        // a delayed flush is scheduled and not yet done
}

// Write writes p to the destination and makes sure a flush will follow:
// immediately for negative latency, otherwise by arming (or re-arming)
// the delayed-flush timer when none is pending.
func (m *maxLatencyWriter) Write(p []byte) (int, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	n, err := m.dst.Write(p)
	switch {
	case m.latency < 0:
		m.dst.Flush()
	case m.flushPending:
		// The armed timer will flush this data too.
	case m.t == nil:
		m.t = time.AfterFunc(m.latency, m.delayedFlush)
		m.flushPending = true
	default:
		m.t.Reset(m.latency)
		m.flushPending = true
	}
	return n, err
}

// delayedFlush runs when the timer fires and flushes if a flush is still
// pending (stop clears the pending flag).
func (m *maxLatencyWriter) delayedFlush() {
	m.mu.Lock()
	defer m.mu.Unlock()
	if m.flushPending {
		m.dst.Flush()
		m.flushPending = false
	}
}

// stop cancels any pending delayed flush and stops the timer, if any.
func (m *maxLatencyWriter) stop() {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.flushPending = false
	if t := m.t; t != nil {
		t.Stop()
	}
}
536
537 func upgradeType(h http.Header) string {
538 if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") {
539 return ""
540 }
541 return strings.ToLower(h.Get("Upgrade"))
542 }
543
544 func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) {
545 reqUpType := upgradeType(req.Header)
546 resUpType := upgradeType(res.Header)
547 if reqUpType != resUpType {
548 p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType))
549 return
550 }
551
552 hj, ok := rw.(http.Hijacker)
553 if !ok {
554 p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw))
555 return
556 }
557 backConn, ok := res.Body.(io.ReadWriteCloser)
558 if !ok {
559 p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body"))
560 return
561 }
562
563 backConnCloseCh := make(chan bool)
564 go func() {
565
566
567 select {
568 case <-req.Context().Done():
569 case <-backConnCloseCh:
570 }
571 backConn.Close()
572 }()
573
574 defer close(backConnCloseCh)
575
576 conn, brw, err := hj.Hijack()
577 if err != nil {
578 p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", err))
579 return
580 }
581 defer conn.Close()
582
583 copyHeader(rw.Header(), res.Header)
584
585 res.Header = rw.Header()
586 res.Body = nil
587 if err := res.Write(brw); err != nil {
588 p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err))
589 return
590 }
591 if err := brw.Flush(); err != nil {
592 p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err))
593 return
594 }
595 errc := make(chan error, 1)
596 spc := switchProtocolCopier{user: conn, backend: backConn}
597 go spc.copyToBackend(errc)
598 go spc.copyFromBackend(errc)
599 <-errc
600 return
601 }
602
603
604
// switchProtocolCopier relays bytes between the hijacked client connection
// (user) and the backend connection after a protocol switch. Each copy
// method sends its io.Copy result (nil on clean EOF) to errc.
type switchProtocolCopier struct {
	user    io.ReadWriter
	backend io.ReadWriter
}

// copyFromBackend copies backend -> user and reports the outcome on errc.
func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
	if _, err := io.Copy(c.user, c.backend); err != nil {
		errc <- err
		return
	}
	errc <- nil
}

// copyToBackend copies user -> backend and reports the outcome on errc.
func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
	if _, err := io.Copy(c.backend, c.user); err != nil {
		errc <- err
		return
	}
	errc <- nil
}
618
View as plain text