Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 38 additions & 2 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,8 +280,8 @@ func NewConnectionWithOptions(conn net.Conn, server bool, opts ...spdy.FramerOpt
// Ping sends a ping frame across the connection and
// returns the response time
func (s *Connection) Ping() (time.Duration, error) {
pid := s.pingId
s.pingLock.Lock()
pid := s.pingId
if s.pingId > 0x7ffffffe {
s.pingId = s.pingId - 0x7ffffffe
} else {
Expand Down Expand Up @@ -740,7 +740,43 @@ func (s *Connection) shutdown(closeTimeout time.Duration) {
var err error
select {
case <-streamsClosed:
// No active streams, close should be safe
// No active streams; now drain inbound traffic before closing so the
// kernel sees FIN, not RST. Background: If a peer packet (e.g. a SPDY
// PING) arrives at our socket after Close(), the kernel responds with
// RST and discards anything still queued in OUR kernel send buffer.
//
// Half-closing the write end of the connection triggers a FIN once the
// send buffer is emptied and prevents new packets from being sent, but
// still allows the receive buffer to be drained.
func() {
cw, ok := s.conn.(interface{ CloseWrite() error })
if !ok {
debugMessage("(%p) connection does not support half-close, skipping drain", s)
return
}
if err := cw.CloseWrite(); err != nil {
debugMessage("(%p) failed to half-close connection: %s", s, err)
return
}
var drainTimeout <-chan time.Time
if closeTimeout == time.Duration(0) {
// no close timeout configured; use fixed drain timeout to avoid
// hanging if peer does not respond or Serve() was not called
drainTimer := time.NewTimer(10 * time.Second)
defer drainTimer.Stop()
drainTimeout = drainTimer.C
}
select {
case <-s.closeChan:
return
case <-timeout:
debugMessage("(%p) close timeout reached", s)
return
case <-drainTimeout:
debugMessage("(%p) drain timeout reached", s)
return
}
}()
err = s.conn.Close()
case <-timeout:
// Force ungraceful close
Expand Down
147 changes: 147 additions & 0 deletions spdy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"net/http"
"net/http/httptest"
"sync"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -1004,6 +1005,152 @@ func TestGoAwayRace(t *testing.T) {
done.Wait()
}

func TestReceiveBufferIsDrained(t *testing.T) {
// 1. Server sends data until tcp window is full
// 2. Server closes stream and connection
// 3. Client starts reading and in between reads sends a ping

// When the ping hits the server side while there is still data in the
// kernel send buffer, the kernel will respond with an RST and discard the
// remaining data.

// Start listener
listener, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Error listening: %v", err)
}
defer listener.Close()
addr := listener.Addr().String()
t.Logf("Listening on: %s", addr)

// Set up the client side
clientTcpConn, err := net.Dial("tcp", addr)
if err != nil {
t.Fatalf("Error dialing server: %s", err)
}
clientSpdyConn, err := NewConnection(clientTcpConn, false)
if err != nil {
t.Fatalf("Error creating spdy connection: %s", err)
}
go clientSpdyConn.Serve(NoOpStreamHandler)

// Set up the server side
serverStreamCh := make(chan *Stream, 1)
serverTcpConn, err := listener.Accept()
if err != nil {
t.Fatalf("Error accepting connection: %v", err)
}
serverSpdyConn, err := NewConnection(serverTcpConn, true) // spdy takes ownership of conn
if err != nil {
t.Fatalf("Error creating server connection: %v", err)
}
go serverSpdyConn.Serve(func(str *Stream) {
str.SendReply(http.Header{}, false)
serverStreamCh <- str
})

// Connect a client stream...
clientStream, err := clientSpdyConn.CreateStream(http.Header{}, nil, true)
if err != nil {
t.Fatalf("Error creating client stream: %v", err)
}
clientStream.Wait() // wait for reply

// ... and wait for it on the server side
serverStream := <-serverStreamCh

// Fill stream until backpressure occurs
var bytesWritten uint64
var writingDone int32
var writerWg sync.WaitGroup
writerWg.Add(1)
go func() {
defer writerWg.Done()
buf := make([]byte, 1024)
for atomic.LoadInt32(&writingDone) == 0 {
n, err := serverStream.Write(buf)
atomic.AddUint64(&bytesWritten, uint64(n))
if err != nil {
return
}
}
}()

// Wait until writer starts
for atomic.LoadUint64(&bytesWritten) == 0 {
time.Sleep(100 * time.Millisecond)
}

// Wait until writer stalls
last := atomic.LoadUint64(&bytesWritten)
for {
time.Sleep(100 * time.Millisecond)
if last == atomic.LoadUint64(&bytesWritten) {
break
}
last = atomic.LoadUint64(&bytesWritten)
}
atomic.StoreInt32(&writingDone, 1)

go func() {
// Close the server side
if err := serverStream.Close(); err != nil {
t.Errorf("Error closing stream: %v", err)
}
if err := serverStream.Reset(); err != nil {
t.Errorf("Error resetting stream: %v", err)
}
if err := serverSpdyConn.Close(); err != nil {
t.Errorf("Error closing spdy conn")
}
}()

// Start sending pings
stopPings := make(chan struct{})
defer close(stopPings)
go func() {
ticker := time.NewTicker(5 * time.Millisecond)
defer ticker.Stop()
var wg sync.WaitGroup
for {
select {
case <-ticker.C:
wg.Add(1)
go func() {
defer wg.Done()
_, _ = clientSpdyConn.Ping()
}()
case <-stopPings:
wg.Wait()
return
}
}
}()

// Start reading
var bytesRead uint64
buf := make([]byte, 1024)
for {
time.Sleep(1 * time.Millisecond) // slow reader — keeps server send buffer non-empty
n, err := clientStream.Read(buf)
bytesRead += uint64(n)
if err != nil {
if err != io.EOF {
t.Logf("read stopped early (possible RST): %v", err)
}
break
}
}
_ = clientStream.Close()
writerWg.Wait()

if bytesRead != atomic.LoadUint64(&bytesWritten) {
t.Errorf("Read less bytes than written: written: %d, read: %d", atomic.LoadUint64(&bytesWritten), bytesRead)
} else {
t.Logf("Successfully read all bytes: %d", bytesRead)
}
}

func TestSetIdleTimeoutAfterRemoteConnectionClosed(t *testing.T) {
listener, err := net.Listen("tcp", "localhost:0")
if err != nil {
Expand Down