diff --git a/connection.go b/connection.go index 69ce477..c340512 100644 --- a/connection.go +++ b/connection.go @@ -280,8 +280,8 @@ func NewConnectionWithOptions(conn net.Conn, server bool, opts ...spdy.FramerOpt // Ping sends a ping frame across the connection and // returns the response time func (s *Connection) Ping() (time.Duration, error) { - pid := s.pingId s.pingLock.Lock() + pid := s.pingId if s.pingId > 0x7ffffffe { s.pingId = s.pingId - 0x7ffffffe } else { @@ -740,7 +740,43 @@ func (s *Connection) shutdown(closeTimeout time.Duration) { var err error select { case <-streamsClosed: - // No active streams, close should be safe + // No active streams; now drain inbound traffic before closing so the + // kernel sees FIN, not RST. Background: If a peer packet (e.g. a SPDY + // PING) arrives at our socket after Close(), the kernel responds with + // RST and discards anything still queued in OUR kernel send buffer. + // + // Half-closing the write end of the connection triggers a FIN once the + // send buffer is emptied and prevents new packets from being sent, but + // still allows the receive buffer to be drained. + func() { + cw, ok := s.conn.(interface{ CloseWrite() error }) + if !ok { + debugMessage("(%p) connection does not support half-close, skipping drain", s) + return + } + if err := cw.CloseWrite(); err != nil { + debugMessage("(%p) failed to half-close connection: %s", s, err) + return + } + var drainTimeout <-chan time.Time + if closeTimeout == time.Duration(0) { + // no close timeout configured; use fixed drain timeout to avoid + // hanging if peer does not respond or Serve() was not called + drainTimer := time.NewTimer(10 * time.Second) + defer drainTimer.Stop() + drainTimeout = drainTimer.C + } + select { + case <-s.closeChan: + return + case <-timeout: + debugMessage("(%p) close timeout reached", s) + return + case <-drainTimeout: + debugMessage("(%p) drain timeout reached", s) + return + } + }() err = s.conn.Close() case <-timeout: // Force ungraceful close diff --git a/spdy_test.go b/spdy_test.go index fdc7591..39f5633 100644 --- a/spdy_test.go +++ b/spdy_test.go @@ -26,6 +26,7 @@ import ( "net/http" "net/http/httptest" "sync" + "sync/atomic" "testing" "time" @@ -1004,6 +1005,152 @@ func TestGoAwayRace(t *testing.T) { done.Wait() } +func TestReceiveBufferIsDrained(t *testing.T) { + // 1. Server sends data until tcp window is full + // 2. Server closes stream and connection + // 3. Client starts reading and in between reads sends a ping + + // When the ping hits the server side while there is still data in the + // kernel send buffer, the kernel will respond with an RST and discard the + // remaining data. + + // Start listener + listener, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("Error listening: %v", err) + } + defer listener.Close() + addr := listener.Addr().String() + t.Logf("Listening on: %s", addr) + + // Set up the client side + clientTcpConn, err := net.Dial("tcp", addr) + if err != nil { + t.Fatalf("Error dialing server: %s", err) + } + clientSpdyConn, err := NewConnection(clientTcpConn, false) + if err != nil { + t.Fatalf("Error creating spdy connection: %s", err) + } + go clientSpdyConn.Serve(NoOpStreamHandler) + + // Set up the server side + serverStreamCh := make(chan *Stream, 1) + serverTcpConn, err := listener.Accept() + if err != nil { + t.Fatalf("Error accepting connection: %v", err) + } + serverSpdyConn, err := NewConnection(serverTcpConn, true) // spdy takes ownership of conn + if err != nil { + t.Fatalf("Error creating server connection: %v", err) + } + go serverSpdyConn.Serve(func(str *Stream) { + str.SendReply(http.Header{}, false) + serverStreamCh <- str + }) + + // Connect a client stream... + clientStream, err := clientSpdyConn.CreateStream(http.Header{}, nil, true) + if err != nil { + t.Fatalf("Error creating client stream: %v", err) + } + clientStream.Wait() // wait for reply + + // ... and wait for it on the server side + serverStream := <-serverStreamCh + + // Fill stream until backpressure occurs + var bytesWritten uint64 + var writingDone int32 + var writerWg sync.WaitGroup + writerWg.Add(1) + go func() { + defer writerWg.Done() + buf := make([]byte, 1024) + for atomic.LoadInt32(&writingDone) == 0 { + n, err := serverStream.Write(buf) + atomic.AddUint64(&bytesWritten, uint64(n)) + if err != nil { + return + } + } + }() + + // Wait until writer starts + for atomic.LoadUint64(&bytesWritten) == 0 { + time.Sleep(100 * time.Millisecond) + } + + // Wait until writer stalls + last := atomic.LoadUint64(&bytesWritten) + for { + time.Sleep(100 * time.Millisecond) + if last == atomic.LoadUint64(&bytesWritten) { + break + } + last = atomic.LoadUint64(&bytesWritten) + } + atomic.StoreInt32(&writingDone, 1) + + go func() { + // Close the server side + if err := serverStream.Close(); err != nil { + t.Errorf("Error closing stream: %v", err) + } + if err := serverStream.Reset(); err != nil { + t.Errorf("Error resetting stream: %v", err) + } + if err := serverSpdyConn.Close(); err != nil { + t.Errorf("Error closing spdy conn") + } + }() + + // Start sending pings + stopPings := make(chan struct{}) + defer close(stopPings) + go func() { + ticker := time.NewTicker(5 * time.Millisecond) + defer ticker.Stop() + var wg sync.WaitGroup + for { + select { + case <-ticker.C: + wg.Add(1) + go func() { + defer wg.Done() + _, _ = clientSpdyConn.Ping() + }() + case <-stopPings: + wg.Wait() + return + } + } + }() + + // Start reading + var bytesRead uint64 + buf := make([]byte, 1024) + for { + time.Sleep(1 * time.Millisecond) // slow reader — keeps server send buffer non-empty + n, err := clientStream.Read(buf) + bytesRead += uint64(n) + if err != nil { + if err != io.EOF { + t.Logf("read stopped early (possible RST): %v", err) + } + break + } + } + _ = clientStream.Close() + writerWg.Wait() + + if bytesRead != atomic.LoadUint64(&bytesWritten) { + t.Errorf("Read less bytes than written: written: %d, read: %d", atomic.LoadUint64(&bytesWritten), bytesRead) + } else { + t.Logf("Successfully read all bytes: %d", bytesRead) + } +} + func TestSetIdleTimeoutAfterRemoteConnectionClosed(t *testing.T) { listener, err := net.Listen("tcp", "localhost:0") if err != nil {