1
2
3
4
5 package httptest
6
7 import (
8 "bytes"
9 "fmt"
10 "io"
11 "net/http"
12 "net/textproto"
13 "strconv"
14 "strings"
15
16 "golang.org/x/net/http/httpguts"
17 )
18
19
20
// ResponseRecorder is an implementation of http.ResponseWriter that records
// the status code, headers, and body written to it so they can be inspected
// afterwards.
type ResponseRecorder struct {
	// Code is the HTTP status code recorded by WriteHeader. An implicit
	// header commit (Write, WriteString, Flush) records 200. NewRecorder
	// initializes it to 200, and Result reports 200 when it is 0.
	Code int

	// HeaderMap contains the headers set by the handler through Header().
	// Header and WriteHeader allocate it lazily when it is nil.
	//
	// The live map keeps changing after the status is committed. The headers
	// as they were at that moment are available as Result().Header, which is
	// a snapshot.
	HeaderMap http.Header

	// Body is the buffer that Write and WriteString append to. If it is nil,
	// written data is discarded while the writes still report success.
	Body *bytes.Buffer

	// Flushed is set to true once Flush has been called.
	Flushed bool

	result      *http.Response // response built and cached by the first Result call
	snapHeader  http.Header    // clone of HeaderMap taken in WriteHeader (or in Result if never written)
	wroteHeader bool           // whether a status code has already been committed
}
49
50
51 func NewRecorder() *ResponseRecorder {
52 return &ResponseRecorder{
53 HeaderMap: make(http.Header),
54 Body: new(bytes.Buffer),
55 Code: 200,
56 }
57 }
58
59
60
// DefaultRemoteAddr is a fixed placeholder remote address: an IPv4 literal
// with no port. It is an untyped string constant.
// NOTE(review): presumably used elsewhere in this package as the default
// client address for synthetic requests; confirm against its callers.
const (
	DefaultRemoteAddr = "1.2.3.4"
)
62
63
64
65
66
67 func (rw *ResponseRecorder) Header() http.Header {
68 m := rw.HeaderMap
69 if m == nil {
70 m = make(http.Header)
71 rw.HeaderMap = m
72 }
73 return m
74 }
75
76
77
78
79
80
81
82
83 func (rw *ResponseRecorder) writeHeader(b []byte, str string) {
84 if rw.wroteHeader {
85 return
86 }
87 if len(str) > 512 {
88 str = str[:512]
89 }
90
91 m := rw.Header()
92
93 _, hasType := m["Content-Type"]
94 hasTE := m.Get("Transfer-Encoding") != ""
95 if !hasType && !hasTE {
96 if b == nil {
97 b = []byte(str)
98 }
99 m.Set("Content-Type", http.DetectContentType(b))
100 }
101
102 rw.WriteHeader(200)
103 }
104
105
106
107 func (rw *ResponseRecorder) Write(buf []byte) (int, error) {
108 rw.writeHeader(buf, "")
109 if rw.Body != nil {
110 rw.Body.Write(buf)
111 }
112 return len(buf), nil
113 }
114
115
116
117 func (rw *ResponseRecorder) WriteString(str string) (int, error) {
118 rw.writeHeader(nil, str)
119 if rw.Body != nil {
120 rw.Body.WriteString(str)
121 }
122 return len(str), nil
123 }
124
125
126 func (rw *ResponseRecorder) WriteHeader(code int) {
127 if rw.wroteHeader {
128 return
129 }
130 rw.Code = code
131 rw.wroteHeader = true
132 if rw.HeaderMap == nil {
133 rw.HeaderMap = make(http.Header)
134 }
135 rw.snapHeader = rw.HeaderMap.Clone()
136 }
137
138
139
140 func (rw *ResponseRecorder) Flush() {
141 if !rw.wroteHeader {
142 rw.WriteHeader(200)
143 }
144 rw.Flushed = true
145 }
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162 func (rw *ResponseRecorder) Result() *http.Response {
163 if rw.result != nil {
164 return rw.result
165 }
166 if rw.snapHeader == nil {
167 rw.snapHeader = rw.HeaderMap.Clone()
168 }
169 res := &http.Response{
170 Proto: "HTTP/1.1",
171 ProtoMajor: 1,
172 ProtoMinor: 1,
173 StatusCode: rw.Code,
174 Header: rw.snapHeader,
175 }
176 rw.result = res
177 if res.StatusCode == 0 {
178 res.StatusCode = 200
179 }
180 res.Status = fmt.Sprintf("%03d %s", res.StatusCode, http.StatusText(res.StatusCode))
181 if rw.Body != nil {
182 res.Body = io.NopCloser(bytes.NewReader(rw.Body.Bytes()))
183 } else {
184 res.Body = http.NoBody
185 }
186 res.ContentLength = parseContentLength(res.Header.Get("Content-Length"))
187
188 if trailers, ok := rw.snapHeader["Trailer"]; ok {
189 res.Trailer = make(http.Header, len(trailers))
190 for _, k := range trailers {
191 k = http.CanonicalHeaderKey(k)
192 if !httpguts.ValidTrailerHeader(k) {
193
194 continue
195 }
196 vv, ok := rw.HeaderMap[k]
197 if !ok {
198 continue
199 }
200 vv2 := make([]string, len(vv))
201 copy(vv2, vv)
202 res.Trailer[k] = vv2
203 }
204 }
205 for k, vv := range rw.HeaderMap {
206 if !strings.HasPrefix(k, http.TrailerPrefix) {
207 continue
208 }
209 if res.Trailer == nil {
210 res.Trailer = make(http.Header)
211 }
212 for _, v := range vv {
213 res.Trailer.Add(strings.TrimPrefix(k, http.TrailerPrefix), v)
214 }
215 }
216 return res
217 }
218
219
220
221
222
223
// parseContentLength interprets a Content-Length header value leniently.
//
// Surrounding ASCII whitespace is trimmed first. If what remains is a plain
// unsigned decimal that fits in 63 bits, that value is returned, and it is
// always a non-negative int64. Any other value returns -1 rather than an
// error: empty, signed ("+5", "-1"), non-numeric, or out of range.
func parseContentLength(cl string) int64 {
	if v := textproto.TrimString(cl); v != "" {
		if n, err := strconv.ParseUint(v, 10, 63); err == nil {
			return int64(n)
		}
	}
	return -1
}
235
View as plain text