...
Run Format

Source file src/net/http/httptest/recorder.go

     1	// Copyright 2011 The Go Authors. All rights reserved.
     2	// Use of this source code is governed by a BSD-style
     3	// license that can be found in the LICENSE file.
     4	
     5	package httptest
     6	
     7	import (
     8		"bytes"
     9		"io/ioutil"
    10		"net/http"
    11		"strconv"
    12		"strings"
    13	)
    14	
// ResponseRecorder is an implementation of http.ResponseWriter that
// records its mutations for later inspection in tests.
//
// Tests can read the exported fields directly, or call the Result method
// to obtain an *http.Response that fills in implicit values (such as a
// 200 status when Code was never set).
type ResponseRecorder struct {
	// Code is the HTTP response code set by WriteHeader.
	//
	// NewRecorder presets it to 200. On a zero-value ResponseRecorder,
	// if a Handler never calls WriteHeader, Write, WriteString or Flush,
	// this stays 0 rather than the implicit http.StatusOK. To get the
	// implicit value, use the Result method.
	Code int

	// HeaderMap contains the headers set by the Handler through Header().
	//
	// On the first Write or WriteString, if the map has no Content-Type
	// key and no Transfer-Encoding header, a sniffed Content-Type is also
	// stored here. Changes made after the status has been written are not
	// reflected in the Response.Header returned by Result, which uses a
	// snapshot; Result does, however, read trailers from this live map.
	//
	// To get the headers as they were when the status was written, plus
	// any trailers, use the Result method.
	HeaderMap http.Header

	// Body is the buffer to which the Handler's Write calls are sent.
	// If nil, the Writes are silently discarded.
	Body *bytes.Buffer

	// Flushed is whether the Handler called Flush.
	Flushed bool

	// result caches Result's return value; once set, Result returns it.
	result *http.Response

	// snapHeader is a copy of HeaderMap taken by the first WriteHeader
	// call (explicit, or implicit via Write, WriteString or Flush), or by
	// Result if no status was ever written. Result uses it as
	// Response.Header.
	snapHeader http.Header

	// wroteHeader is set by the first WriteHeader call; once true, later
	// WriteHeader calls are ignored.
	wroteHeader bool
}
    43	
    44	// NewRecorder returns an initialized ResponseRecorder.
    45	func NewRecorder() *ResponseRecorder {
    46		return &ResponseRecorder{
    47			HeaderMap: make(http.Header),
    48			Body:      new(bytes.Buffer),
    49			Code:      200,
    50		}
    51	}
    52	
    53	// DefaultRemoteAddr is the default remote address to return in RemoteAddr if
    54	// an explicit DefaultRemoteAddr isn't set on ResponseRecorder.
    55	const DefaultRemoteAddr = "1.2.3.4"
    56	
    57	// Header returns the response headers.
    58	func (rw *ResponseRecorder) Header() http.Header {
    59		m := rw.HeaderMap
    60		if m == nil {
    61			m = make(http.Header)
    62			rw.HeaderMap = m
    63		}
    64		return m
    65	}
    66	
    67	// writeHeader writes a header if it was not written yet and
    68	// detects Content-Type if needed.
    69	//
    70	// bytes or str are the beginning of the response body.
    71	// We pass both to avoid unnecessarily generate garbage
    72	// in rw.WriteString which was created for performance reasons.
    73	// Non-nil bytes win.
    74	func (rw *ResponseRecorder) writeHeader(b []byte, str string) {
    75		if rw.wroteHeader {
    76			return
    77		}
    78		if len(str) > 512 {
    79			str = str[:512]
    80		}
    81	
    82		m := rw.Header()
    83	
    84		_, hasType := m["Content-Type"]
    85		hasTE := m.Get("Transfer-Encoding") != ""
    86		if !hasType && !hasTE {
    87			if b == nil {
    88				b = []byte(str)
    89			}
    90			m.Set("Content-Type", http.DetectContentType(b))
    91		}
    92	
    93		rw.WriteHeader(200)
    94	}
    95	
    96	// Write always succeeds and writes to rw.Body, if not nil.
    97	func (rw *ResponseRecorder) Write(buf []byte) (int, error) {
    98		rw.writeHeader(buf, "")
    99		if rw.Body != nil {
   100			rw.Body.Write(buf)
   101		}
   102		return len(buf), nil
   103	}
   104	
   105	// WriteString always succeeds and writes to rw.Body, if not nil.
   106	func (rw *ResponseRecorder) WriteString(str string) (int, error) {
   107		rw.writeHeader(nil, str)
   108		if rw.Body != nil {
   109			rw.Body.WriteString(str)
   110		}
   111		return len(str), nil
   112	}
   113	
   114	// WriteHeader sets rw.Code. After it is called, changing rw.Header
   115	// will not affect rw.HeaderMap.
   116	func (rw *ResponseRecorder) WriteHeader(code int) {
   117		if rw.wroteHeader {
   118			return
   119		}
   120		rw.Code = code
   121		rw.wroteHeader = true
   122		if rw.HeaderMap == nil {
   123			rw.HeaderMap = make(http.Header)
   124		}
   125		rw.snapHeader = cloneHeader(rw.HeaderMap)
   126	}
   127	
   128	func cloneHeader(h http.Header) http.Header {
   129		h2 := make(http.Header, len(h))
   130		for k, vv := range h {
   131			vv2 := make([]string, len(vv))
   132			copy(vv2, vv)
   133			h2[k] = vv2
   134		}
   135		return h2
   136	}
   137	
   138	// Flush sets rw.Flushed to true.
   139	func (rw *ResponseRecorder) Flush() {
   140		if !rw.wroteHeader {
   141			rw.WriteHeader(200)
   142		}
   143		rw.Flushed = true
   144	}
   145	
   146	// Result returns the response generated by the handler.
   147	//
   148	// The returned Response will have at least its StatusCode,
   149	// Header, Body, and optionally Trailer populated.
   150	// More fields may be populated in the future, so callers should
   151	// not DeepEqual the result in tests.
   152	//
   153	// The Response.Header is a snapshot of the headers at the time of the
   154	// first write call, or at the time of this call, if the handler never
   155	// did a write.
   156	//
   157	// The Response.Body is guaranteed to be non-nil and Body.Read call is
   158	// guaranteed to not return any error other than io.EOF.
   159	//
   160	// Result must only be called after the handler has finished running.
   161	func (rw *ResponseRecorder) Result() *http.Response {
   162		if rw.result != nil {
   163			return rw.result
   164		}
   165		if rw.snapHeader == nil {
   166			rw.snapHeader = cloneHeader(rw.HeaderMap)
   167		}
   168		res := &http.Response{
   169			Proto:      "HTTP/1.1",
   170			ProtoMajor: 1,
   171			ProtoMinor: 1,
   172			StatusCode: rw.Code,
   173			Header:     rw.snapHeader,
   174		}
   175		rw.result = res
   176		if res.StatusCode == 0 {
   177			res.StatusCode = 200
   178		}
   179		res.Status = http.StatusText(res.StatusCode)
   180		if rw.Body != nil {
   181			res.Body = ioutil.NopCloser(bytes.NewReader(rw.Body.Bytes()))
   182		}
   183		res.ContentLength = parseContentLength(res.Header.Get("Content-Length"))
   184	
   185		if trailers, ok := rw.snapHeader["Trailer"]; ok {
   186			res.Trailer = make(http.Header, len(trailers))
   187			for _, k := range trailers {
   188				// TODO: use http2.ValidTrailerHeader, but we can't
   189				// get at it easily because it's bundled into net/http
   190				// unexported. This is good enough for now:
   191				switch k {
   192				case "Transfer-Encoding", "Content-Length", "Trailer":
   193					// Ignore since forbidden by RFC 2616 14.40.
   194					continue
   195				}
   196				k = http.CanonicalHeaderKey(k)
   197				vv, ok := rw.HeaderMap[k]
   198				if !ok {
   199					continue
   200				}
   201				vv2 := make([]string, len(vv))
   202				copy(vv2, vv)
   203				res.Trailer[k] = vv2
   204			}
   205		}
   206		for k, vv := range rw.HeaderMap {
   207			if !strings.HasPrefix(k, http.TrailerPrefix) {
   208				continue
   209			}
   210			if res.Trailer == nil {
   211				res.Trailer = make(http.Header)
   212			}
   213			for _, v := range vv {
   214				res.Trailer.Add(strings.TrimPrefix(k, http.TrailerPrefix), v)
   215			}
   216		}
   217		return res
   218	}
   219	
   220	// parseContentLength trims whitespace from s and returns -1 if no value
   221	// is set, or the value if it's >= 0.
   222	//
   223	// This a modified version of same function found in net/http/transfer.go. This
   224	// one just ignores an invalid header.
   225	func parseContentLength(cl string) int64 {
   226		cl = strings.TrimSpace(cl)
   227		if cl == "" {
   228			return -1
   229		}
   230		n, err := strconv.ParseInt(cl, 10, 64)
   231		if err != nil {
   232			return -1
   233		}
   234		return n
   235	}
   236	

View as plain text