golang/golang-1.5.1-a3156aaa12.patch
2015-10-19 20:32:59 -04:00

310 lines
8.3 KiB
Diff

commit a3156aaa121446c4136927f8c2139fefe05ba82c
Author: Brad Fitzpatrick <bradfitz@golang.org>
Date: Tue Sep 29 14:26:48 2015 -0700
net/http/httptest: change Server to use http.Server.ConnState for accounting
With this CL, httptest.Server now uses connection-level accounting of
outstanding requests instead of ServeHTTP-level accounting. This is
more robust and results in a non-racy shutdown.
This is much easier now that net/http.Server has the ConnState hook.
Fixes #12789
Fixes #12781
Change-Id: I098cf334a6494316acb66cd07df90766df41764b
Reviewed-on: https://go-review.googlesource.com/15151
Reviewed-by: Andrew Gerrand <adg@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
diff --git a/src/net/http/httptest/server.go b/src/net/http/httptest/server.go
index 96eb0ef..e4f680f 100644
--- a/src/net/http/httptest/server.go
+++ b/src/net/http/httptest/server.go
@@ -7,13 +7,17 @@
package httptest
import (
+ "bytes"
"crypto/tls"
"flag"
"fmt"
+ "log"
"net"
"net/http"
"os"
+ "runtime"
"sync"
+ "time"
)
// A Server is an HTTP server listening on a system-chosen port on the
@@ -34,24 +38,10 @@ type Server struct {
// wg counts the number of outstanding HTTP requests on this server.
// Close blocks until all requests are finished.
wg sync.WaitGroup
-}
-
-// historyListener keeps track of all connections that it's ever
-// accepted.
-type historyListener struct {
- net.Listener
- sync.Mutex // protects history
- history []net.Conn
-}
-func (hs *historyListener) Accept() (c net.Conn, err error) {
- c, err = hs.Listener.Accept()
- if err == nil {
- hs.Lock()
- hs.history = append(hs.history, c)
- hs.Unlock()
- }
- return
+ mu sync.Mutex // guards closed and conns
+ closed bool
+ conns map[net.Conn]http.ConnState // except terminal states
}
func newLocalListener() net.Listener {
@@ -103,10 +93,9 @@ func (s *Server) Start() {
if s.URL != "" {
panic("Server already started")
}
- s.Listener = &historyListener{Listener: s.Listener}
s.URL = "http://" + s.Listener.Addr().String()
- s.wrapHandler()
- go s.Config.Serve(s.Listener)
+ s.wrap()
+ s.goServe()
if *serve != "" {
fmt.Fprintln(os.Stderr, "httptest: serving on", s.URL)
select {}
@@ -134,23 +123,10 @@ func (s *Server) StartTLS() {
if len(s.TLS.Certificates) == 0 {
s.TLS.Certificates = []tls.Certificate{cert}
}
- tlsListener := tls.NewListener(s.Listener, s.TLS)
-
- s.Listener = &historyListener{Listener: tlsListener}
+ s.Listener = tls.NewListener(s.Listener, s.TLS)
s.URL = "https://" + s.Listener.Addr().String()
- s.wrapHandler()
- go s.Config.Serve(s.Listener)
-}
-
-func (s *Server) wrapHandler() {
- h := s.Config.Handler
- if h == nil {
- h = http.DefaultServeMux
- }
- s.Config.Handler = &waitGroupHandler{
- s: s,
- h: h,
- }
+ s.wrap()
+ s.goServe()
}
// NewTLSServer starts and returns a new Server using TLS.
@@ -161,43 +137,139 @@ func NewTLSServer(handler http.Handler) *Server {
return ts
}
+type closeIdleTransport interface {
+ CloseIdleConnections()
+}
+
// Close shuts down the server and blocks until all outstanding
// requests on this server have completed.
func (s *Server) Close() {
- s.Listener.Close()
- s.wg.Wait()
- s.CloseClientConnections()
- if t, ok := http.DefaultTransport.(*http.Transport); ok {
+ s.mu.Lock()
+ if !s.closed {
+ s.closed = true
+ s.Listener.Close()
+ s.Config.SetKeepAlivesEnabled(false)
+ for c, st := range s.conns {
+ if st == http.StateIdle {
+ s.closeConn(c)
+ }
+ }
+ // If this server doesn't shut down in 5 seconds, tell the user why.
+ t := time.AfterFunc(5*time.Second, s.logCloseHangDebugInfo)
+ defer t.Stop()
+ }
+ s.mu.Unlock()
+
+ // Not part of httptest.Server's correctness, but assume most
+ // users of httptest.Server will be using the standard
+ // transport, so help them out and close any idle connections for them.
+ if t, ok := http.DefaultTransport.(closeIdleTransport); ok {
t.CloseIdleConnections()
}
+
+ s.wg.Wait()
}
-// CloseClientConnections closes any currently open HTTP connections
-// to the test Server.
+func (s *Server) logCloseHangDebugInfo() {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ var buf bytes.Buffer
+ buf.WriteString("httptest.Server blocked in Close after 5 seconds, waiting for connections:\n")
+ for c, st := range s.conns {
+ fmt.Fprintf(&buf, " %T %p %v in state %v\n", c, c, c.RemoteAddr(), st)
+ }
+ log.Print(buf.String())
+}
+
+// CloseClientConnections closes any open HTTP connections to the test Server.
func (s *Server) CloseClientConnections() {
- hl, ok := s.Listener.(*historyListener)
- if !ok {
- return
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ for c := range s.conns {
+ s.closeConn(c)
}
- hl.Lock()
- for _, conn := range hl.history {
- conn.Close()
+}
+
+func (s *Server) goServe() {
+ s.wg.Add(1)
+ go func() {
+ defer s.wg.Done()
+ s.Config.Serve(s.Listener)
+ }()
+}
+
+// wrap installs the connection state-tracking hook to know which
+// connections are idle.
+func (s *Server) wrap() {
+ oldHook := s.Config.ConnState
+ s.Config.ConnState = func(c net.Conn, cs http.ConnState) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ switch cs {
+ case http.StateNew:
+ s.wg.Add(1)
+ if _, exists := s.conns[c]; exists {
+ panic("invalid state transition")
+ }
+ if s.conns == nil {
+ s.conns = make(map[net.Conn]http.ConnState)
+ }
+ s.conns[c] = cs
+ if s.closed {
+ // Probably just a socket-late-binding dial from
+ // the default transport that lost the race (and
+ // thus this connection is now idle and will
+ // never be used).
+ s.closeConn(c)
+ }
+ case http.StateActive:
+ if oldState, ok := s.conns[c]; ok {
+ if oldState != http.StateNew && oldState != http.StateIdle {
+ panic("invalid state transition")
+ }
+ s.conns[c] = cs
+ }
+ case http.StateIdle:
+ if oldState, ok := s.conns[c]; ok {
+ if oldState != http.StateActive {
+ panic("invalid state transition")
+ }
+ s.conns[c] = cs
+ }
+ if s.closed {
+ s.closeConn(c)
+ }
+ case http.StateHijacked, http.StateClosed:
+ s.forgetConn(c)
+ }
+ if oldHook != nil {
+ oldHook(c, cs)
+ }
}
- hl.Unlock()
}
-// waitGroupHandler wraps a handler, incrementing and decrementing a
-// sync.WaitGroup on each request, to enable Server.Close to block
-// until outstanding requests are finished.
-type waitGroupHandler struct {
- s *Server
- h http.Handler // non-nil
+// closeConn closes c. On plan9 it also forgets c first; see the comment below.
+// s.mu must be held.
+func (s *Server) closeConn(c net.Conn) {
+ if runtime.GOOS == "plan9" {
+ // Go's Plan 9 net package isn't great at unblocking reads when
+ // their underlying TCP connections are closed. Don't trust
+ // that the ConnState state machine will get to
+ // StateClosed. Instead, just go there directly. Plan 9 may leak
+ // resources if the syscall doesn't end up returning. Oh well.
+ s.forgetConn(c)
+ }
+ go c.Close()
}
-func (h *waitGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
- h.s.wg.Add(1)
- defer h.s.wg.Done() // a defer, in case ServeHTTP below panics
- h.h.ServeHTTP(w, r)
+// forgetConn removes c from the set of tracked conns and decrements the
+// waitgroup, unless c was already removed.
+// s.mu must be held.
+func (s *Server) forgetConn(c net.Conn) {
+ if _, ok := s.conns[c]; ok {
+ delete(s.conns, c)
+ s.wg.Done()
+ }
}
// localhostCert is a PEM-encoded TLS cert with SAN IPs
diff --git a/src/net/http/httptest/server_test.go b/src/net/http/httptest/server_test.go
index 500a9f0..90901ce 100644
--- a/src/net/http/httptest/server_test.go
+++ b/src/net/http/httptest/server_test.go
@@ -27,3 +27,30 @@ func TestServer(t *testing.T) {
t.Errorf("got %q, want hello", string(got))
}
}
+
+// Issue 12781
+func TestGetAfterClose(t *testing.T) {
+ ts := NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte("hello"))
+ }))
+
+ res, err := http.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ got, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if string(got) != "hello" {
+ t.Fatalf("got %q, want hello", string(got))
+ }
+
+ ts.Close()
+
+ res, err = http.Get(ts.URL)
+ if err == nil {
+ body, _ := ioutil.ReadAll(res.Body)
+ t.Fatalf("Unexected response after close: %v, %v, %s", res.Status, res.Header, body)
+ }
+}