1
2
3
4
5 package websocket
6
7 import (
8 "http"
9 "io"
10 "strings"
11 )
12
13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38
39 type Handler func(*Conn)
40
41 42 43 44
// getKeyNumber extracts the key number from a Sec-WebSocket-Key header
// value. It concatenates every ASCII decimal digit of s, in order, and reads
// the result as a base-ten number. All non-digit bytes are skipped. The
// result is 0 when s contains no digits. The computation is done in uint32,
// so a digit run larger than 4294967295 wraps modulo 2^32.
func getKeyNumber(s string) (r uint32) {
	for _, c := range []byte(s) {
		if c < '0' || c > '9' {
			continue
		}
		r = r*10 + uint32(c-'0')
	}
	return r
}
58
59
// ServeHTTP implements the http.Handler interface for the challenge-based
// WebSocket handshake (Sec-WebSocket-Key1/Key2 plus 8 trailing key bytes).
//
// It hijacks the connection, then validates the request:
//   - the method must be GET;
//   - Upgrade must be "websocket" (case-insensitive);
//   - Connection must contain an "upgrade" token;
//   - Origin, Key1 and Key2 must be present.
//
// It then reads the 8 key bytes, checks that each key's number divides
// evenly by its space count, and writes the 101 response followed by the
// challenge answer. Finally it calls f with the new Conn.
//
// A rejected handshake receives no HTTP response: the deferred Close simply
// drops the connection. The connection is also closed when f returns.
func (f Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	// Hijack up front: key3 is read from the hijacked buffered reader below.
	rwc, buf, err := w.(http.Hijacker).Hijack()
	if err != nil {
		panic("Hijack failed: " + err.String())
	}
	// Every early return after this point aborts the handshake by closing
	// the raw connection.
	defer rwc.Close()

	if req.Method != "GET" {
		return
	}

	// Connection is a comma-separated token list (e.g. "keep-alive, Upgrade"),
	// so look for the token instead of requiring an exact header value.
	if strings.ToLower(req.Header.Get("Upgrade")) != "websocket" ||
		!headerHasToken(req.Header.Get("Connection"), "upgrade") {
		return
	}

	origin := req.Header.Get("Origin")
	if origin == "" {
		return
	}

	key1 := req.Header.Get("Sec-Websocket-Key1")
	if key1 == "" {
		return
	}
	key2 := req.Header.Get("Sec-Websocket-Key2")
	if key2 == "" {
		return
	}
	key3 := make([]byte, 8)
	if _, err := io.ReadFull(buf, key3); err != nil {
		return
	}

	var location string
	if req.TLS != nil {
		location = "wss://" + req.Host + req.URL.RawPath
	} else {
		location = "ws://" + req.Host + req.URL.RawPath
	}

	// Each key yields a number (its digits) and a divisor (its spaces).
	// The handshake is rejected unless both divisions are exact and non-zero.
	keyNumber1 := getKeyNumber(key1)
	keyNumber2 := getKeyNumber(key2)

	space1 := uint32(strings.Count(key1, " "))
	space2 := uint32(strings.Count(key2, " "))
	if space1 == 0 || space2 == 0 {
		return
	}
	if keyNumber1%space1 != 0 || keyNumber2%space2 != 0 {
		return
	}

	part1 := keyNumber1 / space1
	part2 := keyNumber2 / space2

	response, err := getChallengeResponse(part1, part2, key3)
	if err != nil {
		return
	}

	// bufio.Writer errors are sticky, so the Flush check below also covers
	// any failure in these writes.
	buf.WriteString("HTTP/1.1 101 WebSocket Protocol Handshake\r\n")
	buf.WriteString("Upgrade: WebSocket\r\n")
	buf.WriteString("Connection: Upgrade\r\n")
	buf.WriteString("Sec-WebSocket-Location: " + location + "\r\n")
	buf.WriteString("Sec-WebSocket-Origin: " + origin + "\r\n")
	protocol := strings.TrimSpace(req.Header.Get("Sec-Websocket-Protocol"))
	if protocol != "" {
		buf.WriteString("Sec-WebSocket-Protocol: " + protocol + "\r\n")
	}
	buf.WriteString("\r\n")
	// The challenge answer follows the blank line that ends the headers.
	buf.Write(response)
	if err := buf.Flush(); err != nil {
		return
	}

	ws := newConn(origin, location, protocol, buf, rwc)
	ws.Request = req
	f(ws)
}

// headerHasToken reports whether the comma-separated header value contains
// token. Each element is trimmed of surrounding spaces and compared
// case-insensitively.
func headerHasToken(value, token string) bool {
	token = strings.ToLower(token)
	for value != "" {
		item := value
		if i := strings.Index(value, ","); i >= 0 {
			item, value = value[:i], value[i+1:]
		} else {
			value = ""
		}
		if strings.ToLower(strings.TrimSpace(item)) == token {
			return true
		}
	}
	return false
}
156
157 158 159 160
161 type Draft75Handler func(*Conn)
162
163
// ServeHTTP implements the http.Handler interface for the older handshake,
// which has no key challenge.
//
// The request must be a GET over HTTP/1.1 with exactly
// "Upgrade: WebSocket", exactly "Connection: Upgrade", and a non-empty
// Origin. Otherwise it is answered with 400 Bad Request and a short
// explanation. On success the connection is hijacked, the 101 response is
// written, and f runs on the new Conn. The connection is closed when f
// returns.
func (f Draft75Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	// Validate before hijacking so a bad request still receives a proper
	// HTTP error response through w. Header values are compared exactly
	// (case-sensitively) here.
	if req.Method != "GET" || req.Proto != "HTTP/1.1" {
		w.WriteHeader(http.StatusBadRequest)
		io.WriteString(w, "Unexpected request")
		return
	}
	if req.Header.Get("Upgrade") != "WebSocket" {
		w.WriteHeader(http.StatusBadRequest)
		io.WriteString(w, "missing Upgrade: WebSocket header")
		return
	}
	if req.Header.Get("Connection") != "Upgrade" {
		w.WriteHeader(http.StatusBadRequest)
		io.WriteString(w, "missing Connection: Upgrade header")
		return
	}
	origin := strings.TrimSpace(req.Header.Get("Origin"))
	if origin == "" {
		w.WriteHeader(http.StatusBadRequest)
		io.WriteString(w, "missing Origin header")
		return
	}

	rwc, buf, err := w.(http.Hijacker).Hijack()
	if err != nil {
		panic("Hijack failed: " + err.String())
	}
	defer rwc.Close()

	var location string
	if req.TLS != nil {
		location = "wss://" + req.Host + req.URL.RawPath
	} else {
		location = "ws://" + req.Host + req.URL.RawPath
	}

	// bufio.Writer errors are sticky, so the Flush check below also covers
	// any failure in these writes.
	buf.WriteString("HTTP/1.1 101 Web Socket Protocol Handshake\r\n")
	buf.WriteString("Upgrade: WebSocket\r\n")
	buf.WriteString("Connection: Upgrade\r\n")
	buf.WriteString("WebSocket-Origin: " + origin + "\r\n")
	buf.WriteString("WebSocket-Location: " + location + "\r\n")
	protocol := strings.TrimSpace(req.Header.Get("Websocket-Protocol"))
	if protocol != "" {
		buf.WriteString("WebSocket-Protocol: " + protocol + "\r\n")
	}
	buf.WriteString("\r\n")
	if err := buf.Flush(); err != nil {
		return
	}

	ws := newConn(origin, location, protocol, buf, rwc)
	// Expose the originating request to f, consistent with Handler.ServeHTTP.
	ws.Request = req
	f(ws)
}