...
Run Format

Source file src/net/http/httptest/recorder.go

  // Copyright 2011 The Go Authors. All rights reserved.
  // Use of this source code is governed by a BSD-style
  // license that can be found in the LICENSE file.
  
  package httptest
  
  import (
  	"bytes"
  	"io/ioutil"
  	"net/http"
  	"strconv"
  	"strings"
  )
  
  // ResponseRecorder is an implementation of http.ResponseWriter that
  // records its mutations for later inspection in tests.
  //
  // The exported fields may be read directly after the handler has run;
  // Result assembles them into an *http.Response.
  type ResponseRecorder struct {
  	// Code is the HTTP response code set by WriteHeader.
  	//
  	// NewRecorder initializes it to 200. On a zero-value
  	// ResponseRecorder, if a Handler never calls WriteHeader or Write,
  	// this ends up being 0 rather than the implicit http.StatusOK.
  	// To get the implicit value, use the Result method.
  	Code int
  
  	// HeaderMap contains the headers explicitly set by the Handler.
  	// It is the map returned by Header and is allocated lazily if nil.
  	//
  	// To get the implicit headers set by the server (such as
  	// automatic Content-Type), use the Result method.
  	HeaderMap http.Header
  
  	// Body is the buffer to which the Handler's Write and WriteString
  	// calls are sent. If nil, the Writes are silently discarded.
  	Body *bytes.Buffer
  
  	// Flushed is whether the Handler called Flush.
  	Flushed bool
  
  	result      *http.Response // cache of Result's return value; set on the first Result call
  	snapHeader  http.Header    // copy of HeaderMap taken by the first WriteHeader, or by Result if none happened
  	wroteHeader bool           // set by the first WriteHeader; later calls are ignored
  }
  
  // NewRecorder returns an initialized ResponseRecorder.
  func NewRecorder() *ResponseRecorder {
  	return &ResponseRecorder{
  		HeaderMap: make(http.Header),
  		Body:      new(bytes.Buffer),
  		Code:      200,
  	}
  }
  
  // DefaultRemoteAddr is the default remote address to return in RemoteAddr if
  // an explicit DefaultRemoteAddr isn't set on ResponseRecorder.
  //
  // NOTE(review): the visible code never reads this constant, and
  // ResponseRecorder has no RemoteAddr or DefaultRemoteAddr field.
  // Presumably it is kept for callers outside this file; confirm before
  // relying on the sentence above.
  const DefaultRemoteAddr = "1.2.3.4"
  
  // Header returns the response headers.
  func (rw *ResponseRecorder) Header() http.Header {
  	m := rw.HeaderMap
  	if m == nil {
  		m = make(http.Header)
  		rw.HeaderMap = m
  	}
  	return m
  }
  
  // writeHeader writes a header if it was not written yet and
  // detects Content-Type if needed.
  //
  // bytes or str are the beginning of the response body.
  // We pass both to avoid unnecessarily generate garbage
  // in rw.WriteString which was created for performance reasons.
  // Non-nil bytes win.
  func (rw *ResponseRecorder) writeHeader(b []byte, str string) {
  	if rw.wroteHeader {
  		return
  	}
  	if len(str) > 512 {
  		str = str[:512]
  	}
  
  	m := rw.Header()
  
  	_, hasType := m["Content-Type"]
  	hasTE := m.Get("Transfer-Encoding") != ""
  	if !hasType && !hasTE {
  		if b == nil {
  			b = []byte(str)
  		}
  		m.Set("Content-Type", http.DetectContentType(b))
  	}
  
  	rw.WriteHeader(200)
  }
  
  // Write always succeeds and writes to rw.Body, if not nil.
  func (rw *ResponseRecorder) Write(buf []byte) (int, error) {
  	rw.writeHeader(buf, "")
  	if rw.Body != nil {
  		rw.Body.Write(buf)
  	}
  	return len(buf), nil
  }
  
  // WriteString always succeeds and writes to rw.Body, if not nil.
  func (rw *ResponseRecorder) WriteString(str string) (int, error) {
  	rw.writeHeader(nil, str)
  	if rw.Body != nil {
  		rw.Body.WriteString(str)
  	}
  	return len(str), nil
  }
  
  // WriteHeader sets rw.Code. After it is called, changing rw.Header
  // will not affect rw.HeaderMap.
  func (rw *ResponseRecorder) WriteHeader(code int) {
  	if rw.wroteHeader {
  		return
  	}
  	rw.Code = code
  	rw.wroteHeader = true
  	if rw.HeaderMap == nil {
  		rw.HeaderMap = make(http.Header)
  	}
  	rw.snapHeader = cloneHeader(rw.HeaderMap)
  }
  
  func cloneHeader(h http.Header) http.Header {
  	h2 := make(http.Header, len(h))
  	for k, vv := range h {
  		vv2 := make([]string, len(vv))
  		copy(vv2, vv)
  		h2[k] = vv2
  	}
  	return h2
  }
  
  // Flush sets rw.Flushed to true.
  func (rw *ResponseRecorder) Flush() {
  	if !rw.wroteHeader {
  		rw.WriteHeader(200)
  	}
  	rw.Flushed = true
  }
  
  // Result returns the response generated by the handler.
  //
  // The returned Response will have at least its StatusCode,
  // Header, Body, and optionally Trailer populated.
  // More fields may be populated in the future, so callers should
  // not DeepEqual the result in tests.
  //
  // The Response.Header is a snapshot of the headers at the time of the
  // first write call, or at the time of this call, if the handler never
  // did a write.
  //
  // The Response.Body is guaranteed to be non-nil and Body.Read call is
  // guaranteed to not return any error other than io.EOF.
  //
  // Result must only be called after the handler has finished running.
  func (rw *ResponseRecorder) Result() *http.Response {
  	if rw.result != nil {
  		return rw.result
  	}
  	if rw.snapHeader == nil {
  		rw.snapHeader = cloneHeader(rw.HeaderMap)
  	}
  	res := &http.Response{
  		Proto:      "HTTP/1.1",
  		ProtoMajor: 1,
  		ProtoMinor: 1,
  		StatusCode: rw.Code,
  		Header:     rw.snapHeader,
  	}
  	rw.result = res
  	if res.StatusCode == 0 {
  		res.StatusCode = 200
  	}
  	res.Status = http.StatusText(res.StatusCode)
  	if rw.Body != nil {
  		res.Body = ioutil.NopCloser(bytes.NewReader(rw.Body.Bytes()))
  	}
  	res.ContentLength = parseContentLength(res.Header.Get("Content-Length"))
  
  	if trailers, ok := rw.snapHeader["Trailer"]; ok {
  		res.Trailer = make(http.Header, len(trailers))
  		for _, k := range trailers {
  			// TODO: use http2.ValidTrailerHeader, but we can't
  			// get at it easily because it's bundled into net/http
  			// unexported. This is good enough for now:
  			switch k {
  			case "Transfer-Encoding", "Content-Length", "Trailer":
  				// Ignore since forbidden by RFC 2616 14.40.
  				continue
  			}
  			k = http.CanonicalHeaderKey(k)
  			vv, ok := rw.HeaderMap[k]
  			if !ok {
  				continue
  			}
  			vv2 := make([]string, len(vv))
  			copy(vv2, vv)
  			res.Trailer[k] = vv2
  		}
  	}
  	for k, vv := range rw.HeaderMap {
  		if !strings.HasPrefix(k, http.TrailerPrefix) {
  			continue
  		}
  		if res.Trailer == nil {
  			res.Trailer = make(http.Header)
  		}
  		for _, v := range vv {
  			res.Trailer.Add(strings.TrimPrefix(k, http.TrailerPrefix), v)
  		}
  	}
  	return res
  }
  
  // parseContentLength trims whitespace from s and returns -1 if no value
  // is set, or the value if it's >= 0.
  //
  // This a modified version of same function found in net/http/transfer.go. This
  // one just ignores an invalid header.
  func parseContentLength(cl string) int64 {
  	cl = strings.TrimSpace(cl)
  	if cl == "" {
  		return -1
  	}
  	n, err := strconv.ParseInt(cl, 10, 64)
  	if err != nil {
  		return -1
  	}
  	return n
  }
  

View as plain text