...
Source file
src/net/http/npn_test.go
1
2
3
4
5 package http_test
6
7 import (
8 "bufio"
9 "bytes"
10 "crypto/tls"
11 "crypto/x509"
12 "fmt"
13 "io"
14 "io/ioutil"
15 . "net/http"
16 "net/http/httptest"
17 "strings"
18 "testing"
19 )
20
21 func TestNextProtoUpgrade(t *testing.T) {
22 setParallel(t)
23 defer afterTest(t)
24 ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
25 fmt.Fprintf(w, "path=%s,proto=", r.URL.Path)
26 if r.TLS != nil {
27 w.Write([]byte(r.TLS.NegotiatedProtocol))
28 }
29 if r.RemoteAddr == "" {
30 t.Error("request with no RemoteAddr")
31 }
32 if r.Body == nil {
33 t.Errorf("request with nil Body")
34 }
35 }))
36 ts.TLS = &tls.Config{
37 NextProtos: []string{"unhandled-proto", "tls-0.9"},
38 }
39 ts.Config.TLSNextProto = map[string]func(*Server, *tls.Conn, Handler){
40 "tls-0.9": handleTLSProtocol09,
41 }
42 ts.StartTLS()
43 defer ts.Close()
44
45
46 {
47 c := ts.Client()
48 res, err := c.Get(ts.URL)
49 if err != nil {
50 t.Fatal(err)
51 }
52 body, err := ioutil.ReadAll(res.Body)
53 if err != nil {
54 t.Fatal(err)
55 }
56 if want := "path=/,proto="; string(body) != want {
57 t.Errorf("plain request = %q; want %q", body, want)
58 }
59 }
60
61
62
63 {
64 certPool := x509.NewCertPool()
65 certPool.AddCert(ts.Certificate())
66 tr := &Transport{
67 TLSClientConfig: &tls.Config{
68 RootCAs: certPool,
69 NextProtos: []string{"unhandled-proto"},
70 },
71 }
72 defer tr.CloseIdleConnections()
73 c := &Client{
74 Transport: tr,
75 }
76 res, err := c.Get(ts.URL)
77 if err == nil {
78 defer res.Body.Close()
79 var buf bytes.Buffer
80 res.Write(&buf)
81 t.Errorf("expected error on unhandled-proto request; got: %s", buf.Bytes())
82 }
83 }
84
85
86
87 {
88 c := ts.Client()
89 tlsConfig := c.Transport.(*Transport).TLSClientConfig
90 tlsConfig.NextProtos = []string{"tls-0.9"}
91 conn, err := tls.Dial("tcp", ts.Listener.Addr().String(), tlsConfig)
92 if err != nil {
93 t.Fatal(err)
94 }
95 conn.Write([]byte("GET /foo\n"))
96 body, err := ioutil.ReadAll(conn)
97 if err != nil {
98 t.Fatal(err)
99 }
100 if want := "path=/foo,proto=tls-0.9"; string(body) != want {
101 t.Errorf("plain request = %q; want %q", body, want)
102 }
103 }
104 }
105
106
107
108 func handleTLSProtocol09(srv *Server, conn *tls.Conn, h Handler) {
109 br := bufio.NewReader(conn)
110 line, err := br.ReadString('\n')
111 if err != nil {
112 return
113 }
114 line = strings.TrimSpace(line)
115 path := strings.TrimPrefix(line, "GET ")
116 if path == line {
117 return
118 }
119 req, _ := NewRequest("GET", path, nil)
120 req.Proto = "HTTP/0.9"
121 req.ProtoMajor = 0
122 req.ProtoMinor = 9
123 rw := &http09Writer{conn, make(Header)}
124 h.ServeHTTP(rw, req)
125 }
126
// http09Writer is a bare-bones response writer: body bytes go directly to
// the embedded io.Writer, the header map is kept only in memory, and status
// codes are dropped.
type http09Writer struct {
	// Writer receives the body bytes and supplies the Write method.
	io.Writer
	// h is the map handed out by Header; nothing in this type writes it out.
	h Header
}

// Header returns the header map stored in the writer.
func (w http09Writer) Header() Header {
	return w.h
}

// WriteHeader accepts a status code and ignores it.
func (w http09Writer) WriteHeader(int) {
	// Intentionally empty: no status line is emitted by this writer.
}
134
View as plain text