...
Run Format

Source file src/net/http/httptest/server.go

Documentation: net/http/httptest

  // Copyright 2011 The Go Authors. All rights reserved.
  // Use of this source code is governed by a BSD-style
  // license that can be found in the LICENSE file.
  
  // Implementation of Server
  
  package httptest
  
  import (
  	"bytes"
  	"crypto/tls"
  	"crypto/x509"
  	"flag"
  	"fmt"
  	"log"
  	"net"
  	"net/http"
  	"net/http/internal"
  	"os"
  	"sync"
  	"time"
  )
  
  // A Server is an HTTP server listening on a system-chosen port on the
  // local loopback interface, for use in end-to-end HTTP tests.
  type Server struct {
  	// URL and Listener describe where the server accepts connections.
  	// URL is empty until Start or StartTLS sets it; both methods panic if
  	// it is already set. StartTLS also replaces Listener with a TLS
  	// listener that wraps the original one.
  	URL      string // base URL of form http(s)://ipaddr:port with no trailing slash
  	Listener net.Listener
  
  	// TLS is the optional TLS configuration, populated with a new config
  	// after TLS is started. If set on an unstarted server before StartTLS
  	// is called, existing fields are copied into the new config (the
  	// caller's config is cloned, not modified). StartTLS fills in
  	// NextProtos ("http/1.1") and Certificates (a built-in localhost
  	// certificate) only when they are unset.
  	TLS *tls.Config
  
  	// Config may be changed after calling NewUnstartedServer and
  	// before Start or StartTLS. Starting the server installs a ConnState
  	// hook on it that chains to any ConnState hook already set here.
  	Config *http.Server
  
  	// certificate is the parsed form of the first certificate in the TLS
  	// config. It is set only by StartTLS and stays nil for plain HTTP.
  	certificate *x509.Certificate
  
  	// wg counts the Serve goroutine started by goServe plus every
  	// connection currently tracked in conns (added on StateNew, released
  	// by forgetConn). Close blocks until all of them have finished.
  	wg sync.WaitGroup
  
  	// closed is set by the first call to Close. After that, the ConnState
  	// hook closes connections as soon as they are new or become idle.
  	// conns maps each tracked connection to its last observed state;
  	// entries are removed on StateHijacked/StateClosed, so terminal states
  	// never appear in the map.
  	mu     sync.Mutex // guards closed and conns
  	closed bool
  	conns  map[net.Conn]http.ConnState // except terminal states
  
  	// client is configured for use with the server and is created by
  	// Start or StartTLS. Close closes the idle connections of its
  	// transport.
  	client *http.Client
  }
  
  func newLocalListener() net.Listener {
  	if *serve != "" {
  		l, err := net.Listen("tcp", *serve)
  		if err != nil {
  			panic(fmt.Sprintf("httptest: failed to listen on %v: %v", *serve, err))
  		}
  		return l
  	}
  	l, err := net.Listen("tcp", "127.0.0.1:0")
  	if err != nil {
  		if l, err = net.Listen("tcp6", "[::1]:0"); err != nil {
  			panic(fmt.Sprintf("httptest: failed to listen on a port: %v", err))
  		}
  	}
  	return l
  }
  
  // When debugging a particular http server-based test,
  // this flag lets you run
  //	go test -run=BrokenTest -httptest.serve=127.0.0.1:8000
  // to start the broken server so you can interact with it manually.
  var serve = flag.String("httptest.serve", "", "if non-empty, httptest.NewServer serves on this address and blocks")
  
  // NewServer starts and returns a new Server.
  // The caller should call Close when finished, to shut it down.
  func NewServer(handler http.Handler) *Server {
  	ts := NewUnstartedServer(handler)
  	ts.Start()
  	return ts
  }
  
  // NewUnstartedServer returns a new Server but doesn't start it.
  //
  // After changing its configuration, the caller should call Start or
  // StartTLS.
  //
  // The caller should call Close when finished, to shut it down.
  func NewUnstartedServer(handler http.Handler) *Server {
  	return &Server{
  		Listener: newLocalListener(),
  		Config:   &http.Server{Handler: handler},
  	}
  }
  
  // Start starts a server from NewUnstartedServer.
  func (s *Server) Start() {
  	if s.URL != "" {
  		panic("Server already started")
  	}
  	if s.client == nil {
  		s.client = &http.Client{Transport: &http.Transport{}}
  	}
  	s.URL = "http://" + s.Listener.Addr().String()
  	s.wrap()
  	s.goServe()
  	if *serve != "" {
  		fmt.Fprintln(os.Stderr, "httptest: serving on", s.URL)
  		select {}
  	}
  }
  
  // StartTLS starts TLS on a server from NewUnstartedServer.
  func (s *Server) StartTLS() {
  	if s.URL != "" {
  		panic("Server already started")
  	}
  	if s.client == nil {
  		s.client = &http.Client{Transport: &http.Transport{}}
  	}
  	cert, err := tls.X509KeyPair(internal.LocalhostCert, internal.LocalhostKey)
  	if err != nil {
  		panic(fmt.Sprintf("httptest: NewTLSServer: %v", err))
  	}
  
  	existingConfig := s.TLS
  	if existingConfig != nil {
  		s.TLS = existingConfig.Clone()
  	} else {
  		s.TLS = new(tls.Config)
  	}
  	if s.TLS.NextProtos == nil {
  		s.TLS.NextProtos = []string{"http/1.1"}
  	}
  	if len(s.TLS.Certificates) == 0 {
  		s.TLS.Certificates = []tls.Certificate{cert}
  	}
  	s.certificate, err = x509.ParseCertificate(s.TLS.Certificates[0].Certificate[0])
  	if err != nil {
  		panic(fmt.Sprintf("httptest: NewTLSServer: %v", err))
  	}
  	certpool := x509.NewCertPool()
  	certpool.AddCert(s.certificate)
  	s.client.Transport = &http.Transport{
  		TLSClientConfig: &tls.Config{
  			RootCAs: certpool,
  		},
  	}
  	s.Listener = tls.NewListener(s.Listener, s.TLS)
  	s.URL = "https://" + s.Listener.Addr().String()
  	s.wrap()
  	s.goServe()
  }
  
  // NewTLSServer starts and returns a new Server using TLS.
  // The caller should call Close when finished, to shut it down.
  func NewTLSServer(handler http.Handler) *Server {
  	ts := NewUnstartedServer(handler)
  	ts.StartTLS()
  	return ts
  }
  
  // closeIdleTransport is the one method Close needs from a transport.
  // Close type-asserts http.DefaultTransport and the server client's
  // Transport to this interface. For each one that satisfies it, Close
  // calls CloseIdleConnections.
  type closeIdleTransport interface {
  	CloseIdleConnections()
  }
  
  // Close shuts down the server and blocks until all outstanding
  // requests on this server have completed.
  //
  // Only the first call performs the shutdown steps:
  //   - it closes the listener and disables keep-alives;
  //   - it force-closes connections that are idle or new;
  //   - it arms a 5-second timer that logs still-tracked connections (via
  //     logCloseHangDebugInfo) if the wait below takes too long.
  //
  // Every call, including later ones, then closes idle connections on
  // http.DefaultTransport and on the server's client transport. It then
  // waits on s.wg for the Serve goroutine and all tracked connections to
  // finish.
  func (s *Server) Close() {
  	s.mu.Lock()
  	if !s.closed {
  		s.closed = true
  		s.Listener.Close()
  		s.Config.SetKeepAlivesEnabled(false)
  		for c, st := range s.conns {
  			// Force-close any idle connections (those between
  			// requests) and new connections (those which connected
  			// but never sent a request). StateNew connections are
  			// super rare and have only been seen (in
  			// previously-flaky tests) in the case of
  			// socket-late-binding races from the http Client
  			// dialing this server and then getting an idle
  			// connection before the dial completed. There is thus
  			// a connected connection in StateNew with no
  			// associated Request. We only close StateIdle and
  			// StateNew because they're not doing anything. It's
  			// possible StateNew is about to do something in a few
  			// milliseconds, but a previous CL to check again in a
  			// few milliseconds wasn't liked (early versions of
  			// https://golang.org/cl/15151) so now we just
  			// forcefully close StateNew. The docs for Server.Close say
  			// we wait for "outstanding requests", so we don't close things
  			// in StateActive.
  			if st == http.StateIdle || st == http.StateNew {
  				s.closeConn(c)
  			}
  		}
  		// If this server doesn't shut down in 5 seconds, tell the user why.
  		t := time.AfterFunc(5*time.Second, s.logCloseHangDebugInfo)
  		// This defer runs when Close returns (after s.wg.Wait below), not
  		// at the end of this if-block. The timer therefore stays armed for
  		// the whole wait and is cancelled once the wait completes.
  		defer t.Stop()
  	}
  	// Release the lock before waiting: the ConnState hook and
  	// logCloseHangDebugInfo both need s.mu while we block on s.wg.
  	s.mu.Unlock()
  
  	// Not part of httptest.Server's correctness, but assume most
  	// users of httptest.Server will be using the standard
  	// transport, so help them out and close any idle connections for them.
  	if t, ok := http.DefaultTransport.(closeIdleTransport); ok {
  		t.CloseIdleConnections()
  	}
  
  	// Also close the client idle connections.
  	if s.client != nil {
  		if t, ok := s.client.Transport.(closeIdleTransport); ok {
  			t.CloseIdleConnections()
  		}
  	}
  
  	s.wg.Wait()
  }
  
  func (s *Server) logCloseHangDebugInfo() {
  	s.mu.Lock()
  	defer s.mu.Unlock()
  	var buf bytes.Buffer
  	buf.WriteString("httptest.Server blocked in Close after 5 seconds, waiting for connections:\n")
  	for c, st := range s.conns {
  		fmt.Fprintf(&buf, "  %T %p %v in state %v\n", c, c, c.RemoteAddr(), st)
  	}
  	log.Print(buf.String())
  }
  
  // CloseClientConnections closes any open HTTP connections to the test Server.
  func (s *Server) CloseClientConnections() {
  	s.mu.Lock()
  	nconn := len(s.conns)
  	ch := make(chan struct{}, nconn)
  	for c := range s.conns {
  		go s.closeConnChan(c, ch)
  	}
  	s.mu.Unlock()
  
  	// Wait for outstanding closes to finish.
  	//
  	// Out of paranoia for making a late change in Go 1.6, we
  	// bound how long this can wait, since golang.org/issue/14291
  	// isn't fully understood yet. At least this should only be used
  	// in tests.
  	timer := time.NewTimer(5 * time.Second)
  	defer timer.Stop()
  	for i := 0; i < nconn; i++ {
  		select {
  		case <-ch:
  		case <-timer.C:
  			// Too slow. Give up.
  			return
  		}
  	}
  }
  
  // Certificate returns the certificate used by the server, or nil if
  // the server doesn't use TLS.
  func (s *Server) Certificate() *x509.Certificate {
  	return s.certificate
  }
  
  // Client returns an HTTP client configured for making requests to the server.
  // It is configured to trust the server's TLS test certificate and will
  // close its idle connections on Server.Close.
  func (s *Server) Client() *http.Client {
  	return s.client
  }
  
  func (s *Server) goServe() {
  	s.wg.Add(1)
  	go func() {
  		defer s.wg.Done()
  		s.Config.Serve(s.Listener)
  	}()
  }
  
  // wrap installs the connection state-tracking hook to know which
  // connections are idle.
  func (s *Server) wrap() {
  	oldHook := s.Config.ConnState
  	s.Config.ConnState = func(c net.Conn, cs http.ConnState) {
  		s.mu.Lock()
  		defer s.mu.Unlock()
  		switch cs {
  		case http.StateNew:
  			s.wg.Add(1)
  			if _, exists := s.conns[c]; exists {
  				panic("invalid state transition")
  			}
  			if s.conns == nil {
  				s.conns = make(map[net.Conn]http.ConnState)
  			}
  			s.conns[c] = cs
  			if s.closed {
  				// Probably just a socket-late-binding dial from
  				// the default transport that lost the race (and
  				// thus this connection is now idle and will
  				// never be used).
  				s.closeConn(c)
  			}
  		case http.StateActive:
  			if oldState, ok := s.conns[c]; ok {
  				if oldState != http.StateNew && oldState != http.StateIdle {
  					panic("invalid state transition")
  				}
  				s.conns[c] = cs
  			}
  		case http.StateIdle:
  			if oldState, ok := s.conns[c]; ok {
  				if oldState != http.StateActive {
  					panic("invalid state transition")
  				}
  				s.conns[c] = cs
  			}
  			if s.closed {
  				s.closeConn(c)
  			}
  		case http.StateHijacked, http.StateClosed:
  			s.forgetConn(c)
  		}
  		if oldHook != nil {
  			oldHook(c, cs)
  		}
  	}
  }
  
  // closeConn closes c.
  // s.mu must be held.
  func (s *Server) closeConn(c net.Conn) { s.closeConnChan(c, nil) }
  
  // closeConnChan is like closeConn, but takes an optional channel to receive a value
  // when the goroutine closing c is done.
  func (s *Server) closeConnChan(c net.Conn, done chan<- struct{}) {
  	c.Close()
  	if done != nil {
  		done <- struct{}{}
  	}
  }
  
  // forgetConn removes c from the set of tracked conns and decrements it from the
  // waitgroup, unless it was previously removed.
  // s.mu must be held.
  func (s *Server) forgetConn(c net.Conn) {
  	if _, ok := s.conns[c]; ok {
  		delete(s.conns, c)
  		s.wg.Done()
  	}
  }
  

View as plain text