...
1
2
3
4
5 package httptest
6
7 import (
8 "bytes"
9 "fmt"
10 "io/ioutil"
11 "net/http"
12 "strconv"
13 "strings"
14
15 "golang_org/x/net/http/httpguts"
16 )
17
18
19
// ResponseRecorder is an http.ResponseWriter that records the status,
// headers, body, and flush state written to it so tests can inspect them.
type ResponseRecorder struct {
	// Code is the status passed to the first WriteHeader call (Write,
	// WriteString and Flush supply an implicit 200 when none was set).
	// NewRecorder starts it at 200; a zero-value recorder starts at 0, which
	// Result reports as 200.
	Code int

	// HeaderMap is the live header map returned by Header, allocated lazily
	// when nil. Mutations made after the status is written are not visible
	// in Result().Header, which is a snapshot; prefer Result().Header when
	// inspecting the headers a handler actually sent.
	HeaderMap http.Header

	// Body receives every Write/WriteString payload. When nil, written data
	// is discarded (the write still reports success).
	Body *bytes.Buffer

	// Flushed is set to true once Flush has been called.
	Flushed bool

	// result caches the Response built by the first Result call.
	result *http.Response
	// snapHeader is the copy of HeaderMap taken when the status was written
	// (or by Result, if it never was); it backs Result().Header.
	snapHeader http.Header
	// wroteHeader reports whether a status has been recorded.
	wroteHeader bool
}
48
49
50 func NewRecorder() *ResponseRecorder {
51 return &ResponseRecorder{
52 HeaderMap: make(http.Header),
53 Body: new(bytes.Buffer),
54 Code: 200,
55 }
56 }
57
58
59
const (
	// DefaultRemoteAddr is a fixed placeholder remote address (host only,
	// no port) exported by this package. It is not referenced by the
	// recorder code itself; presumably it serves as a default RemoteAddr
	// for test requests elsewhere in the package — verify at call sites.
	DefaultRemoteAddr = "1.2.3.4"
)
61
62
63 func (rw *ResponseRecorder) Header() http.Header {
64 m := rw.HeaderMap
65 if m == nil {
66 m = make(http.Header)
67 rw.HeaderMap = m
68 }
69 return m
70 }
71
72
73
74
75
76
77
78
79 func (rw *ResponseRecorder) writeHeader(b []byte, str string) {
80 if rw.wroteHeader {
81 return
82 }
83 if len(str) > 512 {
84 str = str[:512]
85 }
86
87 m := rw.Header()
88
89 _, hasType := m["Content-Type"]
90 hasTE := m.Get("Transfer-Encoding") != ""
91 if !hasType && !hasTE {
92 if b == nil {
93 b = []byte(str)
94 }
95 m.Set("Content-Type", http.DetectContentType(b))
96 }
97
98 rw.WriteHeader(200)
99 }
100
101
102 func (rw *ResponseRecorder) Write(buf []byte) (int, error) {
103 rw.writeHeader(buf, "")
104 if rw.Body != nil {
105 rw.Body.Write(buf)
106 }
107 return len(buf), nil
108 }
109
110
111 func (rw *ResponseRecorder) WriteString(str string) (int, error) {
112 rw.writeHeader(nil, str)
113 if rw.Body != nil {
114 rw.Body.WriteString(str)
115 }
116 return len(str), nil
117 }
118
119
120
121 func (rw *ResponseRecorder) WriteHeader(code int) {
122 if rw.wroteHeader {
123 return
124 }
125 rw.Code = code
126 rw.wroteHeader = true
127 if rw.HeaderMap == nil {
128 rw.HeaderMap = make(http.Header)
129 }
130 rw.snapHeader = cloneHeader(rw.HeaderMap)
131 }
132
// cloneHeader returns a deep copy of h: a new map whose value slices are
// independent copies, so mutating one header never affects the other.
// Every key maps to a non-nil slice in the copy, and a nil h yields an
// empty, non-nil header.
func cloneHeader(h http.Header) http.Header {
	out := make(http.Header, len(h))
	for key, values := range h {
		// Start from a non-nil, exactly-sized slice so empty or nil inputs
		// still come back as non-nil empty slices.
		out[key] = append(make([]string, 0, len(values)), values...)
	}
	return out
}
142
143
144 func (rw *ResponseRecorder) Flush() {
145 if !rw.wroteHeader {
146 rw.WriteHeader(200)
147 }
148 rw.Flushed = true
149 }
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166 func (rw *ResponseRecorder) Result() *http.Response {
167 if rw.result != nil {
168 return rw.result
169 }
170 if rw.snapHeader == nil {
171 rw.snapHeader = cloneHeader(rw.HeaderMap)
172 }
173 res := &http.Response{
174 Proto: "HTTP/1.1",
175 ProtoMajor: 1,
176 ProtoMinor: 1,
177 StatusCode: rw.Code,
178 Header: rw.snapHeader,
179 }
180 rw.result = res
181 if res.StatusCode == 0 {
182 res.StatusCode = 200
183 }
184 res.Status = fmt.Sprintf("%03d %s", res.StatusCode, http.StatusText(res.StatusCode))
185 if rw.Body != nil {
186 res.Body = ioutil.NopCloser(bytes.NewReader(rw.Body.Bytes()))
187 } else {
188 res.Body = http.NoBody
189 }
190 res.ContentLength = parseContentLength(res.Header.Get("Content-Length"))
191
192 if trailers, ok := rw.snapHeader["Trailer"]; ok {
193 res.Trailer = make(http.Header, len(trailers))
194 for _, k := range trailers {
195 k = http.CanonicalHeaderKey(k)
196 if !httpguts.ValidTrailerHeader(k) {
197
198 continue
199 }
200 vv, ok := rw.HeaderMap[k]
201 if !ok {
202 continue
203 }
204 vv2 := make([]string, len(vv))
205 copy(vv2, vv)
206 res.Trailer[k] = vv2
207 }
208 }
209 for k, vv := range rw.HeaderMap {
210 if !strings.HasPrefix(k, http.TrailerPrefix) {
211 continue
212 }
213 if res.Trailer == nil {
214 res.Trailer = make(http.Header)
215 }
216 for _, v := range vv {
217 res.Trailer.Add(strings.TrimPrefix(k, http.TrailerPrefix), v)
218 }
219 }
220 return res
221 }
222
223
224
225
226
227
// parseContentLength parses a Content-Length header value.
//
// Surrounding whitespace is ignored. It returns -1 (length unknown) when the
// value is empty or is not a plain run of decimal digits that fits in an
// int64. Signed forms such as "-5" or "+5" are rejected: Content-Length is
// defined as 1*DIGIT, and -1 must remain the only negative result.
func parseContentLength(cl string) int64 {
	cl = strings.TrimSpace(cl)
	if cl == "" {
		return -1
	}
	// ParseUint rejects any sign; bitSize 63 caps the value at MaxInt64 so
	// the conversion below cannot overflow.
	n, err := strconv.ParseUint(cl, 10, 63)
	if err != nil {
		return -1
	}
	return int64(n)
}
239
View as plain text