1
2
3
4
5
6
7 package httputil
8
9 import (
10 "context"
11 "fmt"
12 "io"
13 "log"
14 "net"
15 "net/http"
16 "net/textproto"
17 "net/url"
18 "strings"
19 "sync"
20 "time"
21
22 "golang.org/x/net/http/httpguts"
23 )
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41 type ReverseProxy struct {
42
43
44
45
46
47
48 Director func(*http.Request)
49
50
51
52 Transport http.RoundTripper
53
54
55
56
57
58
59
60
61
62
63
64 FlushInterval time.Duration
65
66
67
68
69 ErrorLog *log.Logger
70
71
72
73
74 BufferPool BufferPool
75
76
77
78
79
80
81
82
83
84
85 ModifyResponse func(*http.Response) error
86
87
88
89
90
91
92 ErrorHandler func(http.ResponseWriter, *http.Request, error)
93 }
94
95
96
// BufferPool is a source of reusable byte slices. The proxy takes one slice
// with Get before copying a response body and hands the same slice back with
// Put once the copy is finished.
type BufferPool interface {
	// Get returns a buffer to copy with.
	Get() []byte
	// Put returns a buffer previously obtained from Get.
	Put(buf []byte)
}
101
// singleJoiningSlash concatenates a and b so that exactly one "/" sits at
// the seam: a doubled slash is collapsed and a missing one is inserted.
func singleJoiningSlash(a, b string) string {
	trailing := strings.HasSuffix(a, "/")
	leading := strings.HasPrefix(b, "/")
	if trailing && leading {
		// Drop b's slash; a already ends with one.
		return a + b[1:]
	}
	if trailing || leading {
		// Exactly one side contributes the separator.
		return a + b
	}
	return a + "/" + b
}
113
114 func joinURLPath(a, b *url.URL) (path, rawpath string) {
115 if a.RawPath == "" && b.RawPath == "" {
116 return singleJoiningSlash(a.Path, b.Path), ""
117 }
118
119
120 apath := a.EscapedPath()
121 bpath := b.EscapedPath()
122
123 aslash := strings.HasSuffix(apath, "/")
124 bslash := strings.HasPrefix(bpath, "/")
125
126 switch {
127 case aslash && bslash:
128 return a.Path + b.Path[1:], apath + bpath[1:]
129 case !aslash && !bslash:
130 return a.Path + "/" + b.Path, apath + "/" + bpath
131 }
132 return a.Path + b.Path, apath + bpath
133 }
134
135
136
137
138
139
140
141
142 func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy {
143 targetQuery := target.RawQuery
144 director := func(req *http.Request) {
145 req.URL.Scheme = target.Scheme
146 req.URL.Host = target.Host
147 req.URL.Path, req.URL.RawPath = joinURLPath(target, req.URL)
148 if targetQuery == "" || req.URL.RawQuery == "" {
149 req.URL.RawQuery = targetQuery + req.URL.RawQuery
150 } else {
151 req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
152 }
153 if _, ok := req.Header["User-Agent"]; !ok {
154
155 req.Header.Set("User-Agent", "")
156 }
157 }
158 return &ReverseProxy{Director: director}
159 }
160
// copyHeader appends every value of every field in src to dst. Values that
// dst already holds for a field are kept, and field names are canonicalized
// by Header.Add.
func copyHeader(dst, src http.Header) {
	for name, values := range src {
		for i := range values {
			dst.Add(name, values[i])
		}
	}
}
168
169
170
171
172
173
// hopHeaders lists the hop-by-hop header fields that ServeHTTP strips from
// the outgoing request and from the backend's response. The names are in
// canonical form, which is how http.Header stores them.
var hopHeaders = []string{
	"Connection",
	"Proxy-Connection",
	"Keep-Alive",
	"Proxy-Authenticate",
	"Proxy-Authorization",
	"Te", // canonical spelling of "TE"
	"Trailer",
	"Transfer-Encoding",
	"Upgrade",
}
185
186 func (p *ReverseProxy) defaultErrorHandler(rw http.ResponseWriter, req *http.Request, err error) {
187 p.logf("http: proxy error: %v", err)
188 rw.WriteHeader(http.StatusBadGateway)
189 }
190
191 func (p *ReverseProxy) getErrorHandler() func(http.ResponseWriter, *http.Request, error) {
192 if p.ErrorHandler != nil {
193 return p.ErrorHandler
194 }
195 return p.defaultErrorHandler
196 }
197
198
199
200 func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, req *http.Request) bool {
201 if p.ModifyResponse == nil {
202 return true
203 }
204 if err := p.ModifyResponse(res); err != nil {
205 res.Body.Close()
206 p.getErrorHandler()(rw, req, err)
207 return false
208 }
209 return true
210 }
211
212 func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
213 transport := p.Transport
214 if transport == nil {
215 transport = http.DefaultTransport
216 }
217
218 ctx := req.Context()
219 if cn, ok := rw.(http.CloseNotifier); ok {
220 var cancel context.CancelFunc
221 ctx, cancel = context.WithCancel(ctx)
222 defer cancel()
223 notifyChan := cn.CloseNotify()
224 go func() {
225 select {
226 case <-notifyChan:
227 cancel()
228 case <-ctx.Done():
229 }
230 }()
231 }
232
233 outreq := req.Clone(ctx)
234 if req.ContentLength == 0 {
235 outreq.Body = nil
236 }
237 if outreq.Header == nil {
238 outreq.Header = make(http.Header)
239 }
240
241 p.Director(outreq)
242 outreq.Close = false
243
244 reqUpType := upgradeType(outreq.Header)
245 removeConnectionHeaders(outreq.Header)
246
247
248
249
250 for _, h := range hopHeaders {
251 hv := outreq.Header.Get(h)
252 if hv == "" {
253 continue
254 }
255 if h == "Te" && hv == "trailers" {
256
257
258
259
260
261
262 continue
263 }
264 outreq.Header.Del(h)
265 }
266
267
268
269 if reqUpType != "" {
270 outreq.Header.Set("Connection", "Upgrade")
271 outreq.Header.Set("Upgrade", reqUpType)
272 }
273
274 if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
275
276
277
278 prior, ok := outreq.Header["X-Forwarded-For"]
279 omit := ok && prior == nil
280 if len(prior) > 0 {
281 clientIP = strings.Join(prior, ", ") + ", " + clientIP
282 }
283 if !omit {
284 outreq.Header.Set("X-Forwarded-For", clientIP)
285 }
286 }
287
288 res, err := transport.RoundTrip(outreq)
289 if err != nil {
290 p.getErrorHandler()(rw, outreq, err)
291 return
292 }
293
294
295 if res.StatusCode == http.StatusSwitchingProtocols {
296 if !p.modifyResponse(rw, res, outreq) {
297 return
298 }
299 p.handleUpgradeResponse(rw, outreq, res)
300 return
301 }
302
303 removeConnectionHeaders(res.Header)
304
305 for _, h := range hopHeaders {
306 res.Header.Del(h)
307 }
308
309 if !p.modifyResponse(rw, res, outreq) {
310 return
311 }
312
313 copyHeader(rw.Header(), res.Header)
314
315
316
317 announcedTrailers := len(res.Trailer)
318 if announcedTrailers > 0 {
319 trailerKeys := make([]string, 0, len(res.Trailer))
320 for k := range res.Trailer {
321 trailerKeys = append(trailerKeys, k)
322 }
323 rw.Header().Add("Trailer", strings.Join(trailerKeys, ", "))
324 }
325
326 rw.WriteHeader(res.StatusCode)
327
328 err = p.copyResponse(rw, res.Body, p.flushInterval(req, res))
329 if err != nil {
330 defer res.Body.Close()
331
332
333
334 if !shouldPanicOnCopyError(req) {
335 p.logf("suppressing panic for copyResponse error in test; copy error: %v", err)
336 return
337 }
338 panic(http.ErrAbortHandler)
339 }
340 res.Body.Close()
341
342 if len(res.Trailer) > 0 {
343
344
345
346 if fl, ok := rw.(http.Flusher); ok {
347 fl.Flush()
348 }
349 }
350
351 if len(res.Trailer) == announcedTrailers {
352 copyHeader(rw.Header(), res.Trailer)
353 return
354 }
355
356 for k, vv := range res.Trailer {
357 k = http.TrailerPrefix + k
358 for _, v := range vv {
359 rw.Header().Add(k, v)
360 }
361 }
362 }
363
// inOurTests forces shouldPanicOnCopyError to report true when set.
// NOTE(review): presumably set only by this package's own tests; confirm.
var inOurTests bool
365
366
367
368
369
370
371 func shouldPanicOnCopyError(req *http.Request) bool {
372 if inOurTests {
373
374 return true
375 }
376 if req.Context().Value(http.ServerContextKey) != nil {
377
378
379 return true
380 }
381
382
383 return false
384 }
385
386
387
// removeConnectionHeaders deletes from h every header field that is named
// as a token in one of h's Connection values. The Connection field itself is
// left in place unless it names itself.
func removeConnectionHeaders(h http.Header) {
	for _, line := range h["Connection"] {
		for _, token := range strings.Split(line, ",") {
			name := textproto.TrimString(token)
			if name == "" {
				// Empty list element, e.g. from "a,,b" or a blank value.
				continue
			}
			h.Del(name)
		}
	}
}
397
398
399
400 func (p *ReverseProxy) flushInterval(req *http.Request, res *http.Response) time.Duration {
401 resCT := res.Header.Get("Content-Type")
402
403
404
405 if resCT == "text/event-stream" {
406 return -1
407 }
408
409
410 return p.FlushInterval
411 }
412
413 func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader, flushInterval time.Duration) error {
414 if flushInterval != 0 {
415 if wf, ok := dst.(writeFlusher); ok {
416 mlw := &maxLatencyWriter{
417 dst: wf,
418 latency: flushInterval,
419 }
420 defer mlw.stop()
421
422
423 mlw.flushPending = true
424 mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush)
425
426 dst = mlw
427 }
428 }
429
430 var buf []byte
431 if p.BufferPool != nil {
432 buf = p.BufferPool.Get()
433 defer p.BufferPool.Put(buf)
434 }
435 _, err := p.copyBuffer(dst, src, buf)
436 return err
437 }
438
439
440
441 func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) {
442 if len(buf) == 0 {
443 buf = make([]byte, 32*1024)
444 }
445 var written int64
446 for {
447 nr, rerr := src.Read(buf)
448 if rerr != nil && rerr != io.EOF && rerr != context.Canceled {
449 p.logf("httputil: ReverseProxy read error during body copy: %v", rerr)
450 }
451 if nr > 0 {
452 nw, werr := dst.Write(buf[:nr])
453 if nw > 0 {
454 written += int64(nw)
455 }
456 if werr != nil {
457 return written, werr
458 }
459 if nr != nw {
460 return written, io.ErrShortWrite
461 }
462 }
463 if rerr != nil {
464 if rerr == io.EOF {
465 rerr = nil
466 }
467 return written, rerr
468 }
469 }
470 }
471
472 func (p *ReverseProxy) logf(format string, args ...interface{}) {
473 if p.ErrorLog != nil {
474 p.ErrorLog.Printf(format, args...)
475 } else {
476 log.Printf(format, args...)
477 }
478 }
479
// writeFlusher is a writer whose buffered output can be pushed out on
// demand.
type writeFlusher interface {
	io.Writer
	http.Flusher
}

// maxLatencyWriter wraps a writeFlusher and bounds how long written data may
// sit unflushed: a negative latency flushes after every write, otherwise a
// flush is scheduled latency after the first write that follows a flush.
type maxLatencyWriter struct {
	dst     writeFlusher
	latency time.Duration

	// mu guards t and flushPending and serializes writes and flushes on
	// dst.
	mu           sync.Mutex
	t            *time.Timer
	flushPending bool
}

// Write forwards p to the underlying writer and returns its result. With a
// negative latency it flushes straight away; otherwise it arms the flush
// timer unless a flush is already pending.
func (m *maxLatencyWriter) Write(p []byte) (n int, err error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	n, err = m.dst.Write(p)
	switch {
	case m.latency < 0:
		m.dst.Flush()
	case m.flushPending:
		// The timer is already running; this write rides along with it.
	case m.t == nil:
		m.t = time.AfterFunc(m.latency, m.delayedFlush)
		m.flushPending = true
	default:
		m.t.Reset(m.latency)
		m.flushPending = true
	}
	return n, err
}

// delayedFlush is the timer callback. It flushes the underlying writer if a
// flush is still pending; after stop has cleared the pending flag it does
// nothing.
func (m *maxLatencyWriter) delayedFlush() {
	m.mu.Lock()
	defer m.mu.Unlock()

	if m.flushPending {
		m.dst.Flush()
		m.flushPending = false
	}
}

// stop cancels any pending flush and stops the timer if one was created.
func (m *maxLatencyWriter) stop() {
	m.mu.Lock()
	defer m.mu.Unlock()

	m.flushPending = false
	if timer := m.t; timer != nil {
		timer.Stop()
	}
}
532
533 func upgradeType(h http.Header) string {
534 if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") {
535 return ""
536 }
537 return strings.ToLower(h.Get("Upgrade"))
538 }
539
540 func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) {
541 reqUpType := upgradeType(req.Header)
542 resUpType := upgradeType(res.Header)
543 if reqUpType != resUpType {
544 p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType))
545 return
546 }
547
548 copyHeader(res.Header, rw.Header())
549
550 hj, ok := rw.(http.Hijacker)
551 if !ok {
552 p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw))
553 return
554 }
555 backConn, ok := res.Body.(io.ReadWriteCloser)
556 if !ok {
557 p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body"))
558 return
559 }
560
561 backConnCloseCh := make(chan bool)
562 go func() {
563
564
565 select {
566 case <-req.Context().Done():
567 case <-backConnCloseCh:
568 }
569 backConn.Close()
570 }()
571
572 defer close(backConnCloseCh)
573
574 conn, brw, err := hj.Hijack()
575 if err != nil {
576 p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", err))
577 return
578 }
579 defer conn.Close()
580 res.Body = nil
581 if err := res.Write(brw); err != nil {
582 p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err))
583 return
584 }
585 if err := brw.Flush(); err != nil {
586 p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err))
587 return
588 }
589 errc := make(chan error, 1)
590 spc := switchProtocolCopier{user: conn, backend: backConn}
591 go spc.copyToBackend(errc)
592 go spc.copyFromBackend(errc)
593 <-errc
594 return
595 }
596
597
598
// switchProtocolCopier holds the two ends of an upgraded connection, the
// client ("user") and the backend, and copies data between them.
type switchProtocolCopier struct {
	user, backend io.ReadWriter
}

// copyFromBackend copies everything the backend sends to the user and then
// delivers the result of the copy on errc.
func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
	_, copyErr := io.Copy(c.user, c.backend)
	errc <- copyErr
}

// copyToBackend copies everything the user sends to the backend and then
// delivers the result of the copy on errc.
func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
	_, copyErr := io.Copy(c.backend, c.user)
	errc <- copyErr
}
612
View as plain text