commit a3156aaa121446c4136927f8c2139fefe05ba82c Author: Brad Fitzpatrick Date: Tue Sep 29 14:26:48 2015 -0700 net/http/httptest: change Server to use http.Server.ConnState for accounting With this CL, httptest.Server now uses connection-level accounting of outstanding requests instead of ServeHTTP-level accounting. This is more robust and results in a non-racy shutdown. This is much easier now that net/http.Server has the ConnState hook. Fixes #12789 Fixes #12781 Change-Id: I098cf334a6494316acb66cd07df90766df41764b Reviewed-on: https://go-review.googlesource.com/15151 Reviewed-by: Andrew Gerrand Run-TryBot: Brad Fitzpatrick TryBot-Result: Gobot Gobot diff --git a/src/net/http/httptest/server.go b/src/net/http/httptest/server.go index 96eb0ef..e4f680f 100644 --- a/src/net/http/httptest/server.go +++ b/src/net/http/httptest/server.go @@ -7,13 +7,17 @@ package httptest import ( + "bytes" "crypto/tls" "flag" "fmt" + "log" "net" "net/http" "os" + "runtime" "sync" + "time" ) // A Server is an HTTP server listening on a system-chosen port on the @@ -34,24 +38,10 @@ type Server struct { // wg counts the number of outstanding HTTP requests on this server. // Close blocks until all requests are finished. wg sync.WaitGroup -} - -// historyListener keeps track of all connections that it's ever -// accepted. 
-type historyListener struct { - net.Listener - sync.Mutex // protects history - history []net.Conn -} -func (hs *historyListener) Accept() (c net.Conn, err error) { - c, err = hs.Listener.Accept() - if err == nil { - hs.Lock() - hs.history = append(hs.history, c) - hs.Unlock() - } - return + mu sync.Mutex // guards closed and conns + closed bool + conns map[net.Conn]http.ConnState // except terminal states } func newLocalListener() net.Listener { @@ -103,10 +93,9 @@ func (s *Server) Start() { if s.URL != "" { panic("Server already started") } - s.Listener = &historyListener{Listener: s.Listener} s.URL = "http://" + s.Listener.Addr().String() - s.wrapHandler() - go s.Config.Serve(s.Listener) + s.wrap() + s.goServe() if *serve != "" { fmt.Fprintln(os.Stderr, "httptest: serving on", s.URL) select {} @@ -134,23 +123,10 @@ func (s *Server) StartTLS() { if len(s.TLS.Certificates) == 0 { s.TLS.Certificates = []tls.Certificate{cert} } - tlsListener := tls.NewListener(s.Listener, s.TLS) - - s.Listener = &historyListener{Listener: tlsListener} + s.Listener = tls.NewListener(s.Listener, s.TLS) s.URL = "https://" + s.Listener.Addr().String() - s.wrapHandler() - go s.Config.Serve(s.Listener) -} - -func (s *Server) wrapHandler() { - h := s.Config.Handler - if h == nil { - h = http.DefaultServeMux - } - s.Config.Handler = &waitGroupHandler{ - s: s, - h: h, - } + s.wrap() + s.goServe() } // NewTLSServer starts and returns a new Server using TLS. @@ -161,43 +137,139 @@ func NewTLSServer(handler http.Handler) *Server { return ts } +type closeIdleTransport interface { + CloseIdleConnections() +} + // Close shuts down the server and blocks until all outstanding // requests on this server have completed. 
func (s *Server) Close() { - s.Listener.Close() - s.wg.Wait() - s.CloseClientConnections() - if t, ok := http.DefaultTransport.(*http.Transport); ok { + s.mu.Lock() + if !s.closed { + s.closed = true + s.Listener.Close() + s.Config.SetKeepAlivesEnabled(false) + for c, st := range s.conns { + if st == http.StateIdle { + s.closeConn(c) + } + } + // If this server doesn't shut down in 5 seconds, tell the user why. + t := time.AfterFunc(5*time.Second, s.logCloseHangDebugInfo) + defer t.Stop() + } + s.mu.Unlock() + + // Not part of httptest.Server's correctness, but assume most + // users of httptest.Server will be using the standard + // transport, so help them out and close any idle connections for them. + if t, ok := http.DefaultTransport.(closeIdleTransport); ok { t.CloseIdleConnections() } + + s.wg.Wait() } -// CloseClientConnections closes any currently open HTTP connections -// to the test Server. +func (s *Server) logCloseHangDebugInfo() { + s.mu.Lock() + defer s.mu.Unlock() + var buf bytes.Buffer + buf.WriteString("httptest.Server blocked in Close after 5 seconds, waiting for connections:\n") + for c, st := range s.conns { + fmt.Fprintf(&buf, " %T %p %v in state %v\n", c, c, c.RemoteAddr(), st) + } + log.Print(buf.String()) +} + +// CloseClientConnections closes any open HTTP connections to the test Server. func (s *Server) CloseClientConnections() { - hl, ok := s.Listener.(*historyListener) - if !ok { - return + s.mu.Lock() + defer s.mu.Unlock() + for c := range s.conns { + s.closeConn(c) } - hl.Lock() - for _, conn := range hl.history { - conn.Close() +} + +func (s *Server) goServe() { + s.wg.Add(1) + go func() { + defer s.wg.Done() + s.Config.Serve(s.Listener) + }() +} + +// wrap installs the connection state-tracking hook to know which +// connections are idle. 
+func (s *Server) wrap() { + oldHook := s.Config.ConnState + s.Config.ConnState = func(c net.Conn, cs http.ConnState) { + s.mu.Lock() + defer s.mu.Unlock() + switch cs { + case http.StateNew: + s.wg.Add(1) + if _, exists := s.conns[c]; exists { + panic("invalid state transition") + } + if s.conns == nil { + s.conns = make(map[net.Conn]http.ConnState) + } + s.conns[c] = cs + if s.closed { + // Probably just a socket-late-binding dial from + // the default transport that lost the race (and + // thus this connection is now idle and will + // never be used). + s.closeConn(c) + } + case http.StateActive: + if oldState, ok := s.conns[c]; ok { + if oldState != http.StateNew && oldState != http.StateIdle { + panic("invalid state transition") + } + s.conns[c] = cs + } + case http.StateIdle: + if oldState, ok := s.conns[c]; ok { + if oldState != http.StateActive { + panic("invalid state transition") + } + s.conns[c] = cs + } + if s.closed { + s.closeConn(c) + } + case http.StateHijacked, http.StateClosed: + s.forgetConn(c) + } + if oldHook != nil { + oldHook(c, cs) + } } - hl.Unlock() } -// waitGroupHandler wraps a handler, incrementing and decrementing a -// sync.WaitGroup on each request, to enable Server.Close to block -// until outstanding requests are finished. -type waitGroupHandler struct { - s *Server - h http.Handler // non-nil +// closeConn closes c. Except on plan9, which is special. See comment below. +// s.mu must be held. +func (s *Server) closeConn(c net.Conn) { + if runtime.GOOS == "plan9" { + // Go's Plan 9 net package isn't great at unblocking reads when + // their underlying TCP connections are closed. Don't trust + // that the ConnState state machine will get to + // StateClosed. Instead, just go there directly. Plan 9 may leak + // resources if the syscall doesn't end up returning. Oh well.
+ s.forgetConn(c) + } + go c.Close() } -func (h *waitGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - h.s.wg.Add(1) - defer h.s.wg.Done() // a defer, in case ServeHTTP below panics - h.h.ServeHTTP(w, r) +// forgetConn removes c from the set of tracked conns and decrements it from the +// waitgroup, unless it was previously removed. +// s.mu must be held. +func (s *Server) forgetConn(c net.Conn) { + if _, ok := s.conns[c]; ok { + delete(s.conns, c) + s.wg.Done() + } } // localhostCert is a PEM-encoded TLS cert with SAN IPs diff --git a/src/net/http/httptest/server_test.go b/src/net/http/httptest/server_test.go index 500a9f0..90901ce 100644 --- a/src/net/http/httptest/server_test.go +++ b/src/net/http/httptest/server_test.go @@ -27,3 +27,30 @@ func TestServer(t *testing.T) { t.Errorf("got %q, want hello", string(got)) } } + +// Issue 12781 +func TestGetAfterClose(t *testing.T) { + ts := NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("hello")) + })) + + res, err := http.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + got, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if string(got) != "hello" { + t.Fatalf("got %q, want hello", string(got)) + } + + ts.Close() + + res, err = http.Get(ts.URL) + if err == nil { + body, _ := ioutil.ReadAll(res.Body) + t.Fatalf("Unexected response after close: %v, %v, %s", res.Status, res.Header, body) + } +}