diff --git a/CHANGELOG.md b/CHANGELOG.md index 1b5623a..8155b75 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [1.5.1] - 2026-05-11 + +### Fixed + +- **`writePump` no longer exits early on hub context cancellation** — shutdown now wakes the writer through the normal client close path (`client.send` / `client.done`) instead of selecting directly on `h.ctx.Done()`, giving queued messages and WebSocket close frames a chance to flush before the pump exits + ## [1.5.0] - 2026-04-07 ### Added diff --git a/client.go b/client.go index dd3ce4f..1cf6675 100644 --- a/client.go +++ b/client.go @@ -643,7 +643,19 @@ func (c *Client) writeCoalescedBatch(first sendItem, n int) bool { } // writePump pumps messages from the hub to the WebSocket connection. -func (c *Client) writePump(ctx context.Context) { +// +// Shutdown signals: +// - CloseWithCode → closes c.send → exits via the send case with ok=false, +// sending a close frame. +// - handleUnregister (remote/abnormal close) → closes c.done → exits without +// a close frame (the connection is already gone). +// - hub.Shutdown → calls Close on each client → same path as CloseWithCode. +// +// We deliberately do NOT select on the hub's context. Hub shutdown closes +// every client's send channel, which already wakes the pump; a ctx.Done() +// case can exit before queued messages and the close frame are flushed, and +// adds per-iteration cost (selectgo grows roughly linearly with case count). 
+func (c *Client) writePump() { ticker := time.NewTicker(c.config.PingPeriod) defer func() { ticker.Stop() @@ -652,9 +664,6 @@ func (c *Client) writePump(ctx context.Context) { for { select { - case <-ctx.Done(): - return - case <-c.done: // Client was unregistered (remote/abnormal close). Exit // without sending a close frame — the connection is gone. diff --git a/client_test.go b/client_test.go index 844bdd1..e346c6d 100644 --- a/client_test.go +++ b/client_test.go @@ -979,3 +979,161 @@ func TestCoalesceWritesDisabledByDefault(t *testing.T) { } } } + +// waitFor polls fn at 10ms intervals until it returns true or timeout elapses. +func waitFor(t *testing.T, timeout time.Duration, fn func() bool, msg string) { + t.Helper() + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + if fn() { + return + } + time.Sleep(10 * time.Millisecond) + } + t.Fatalf("waitFor timed out after %s: %s", timeout, msg) +} + +func waitForClient(t *testing.T, hub *Hub) *Client { + t.Helper() + + var client *Client + waitFor(t, time.Second, func() bool { + clients := hub.Clients() + if len(clients) == 0 { + return false + } + client = clients[0] + return true + }, "client snapshot to include registered client") + return client +} + +// TestWritePumpExitsOnCloseWithCode verifies that calling CloseWithCode +// drives writePump to send a close frame and exit. With ctx.Done() removed +// from the writePump select, the path is: CloseWithCode → close(c.send) → +// writePump observes ok=false → writeCloseFrame → return. +func TestWritePumpExitsOnCloseWithCode(t *testing.T) { + hub, dial := setupClientTest(t) + conn := dial() + + client := waitForClient(t, hub) + if err := client.CloseWithCode(websocket.CloseGoingAway, "shutting down"); err != nil { + t.Fatalf("CloseWithCode: %v", err) + } + + // Client side should observe a close frame, then EOF. 
+ _ = conn.SetReadDeadline(time.Now().Add(time.Second)) + _, _, err := conn.ReadMessage() + if err == nil { + t.Fatal("expected close error from server, got nil") + } + ce, ok := err.(*websocket.CloseError) + if !ok { + t.Fatalf("expected *websocket.CloseError, got %T: %v", err, err) + } + if ce.Code != websocket.CloseGoingAway { + t.Errorf("close code = %d, want %d", ce.Code, websocket.CloseGoingAway) + } + + // writePump must have exited and unregistered the client. + waitFor(t, time.Second, func() bool { return hub.ClientCount() == 0 }, + "client to be unregistered after CloseWithCode") +} + +// TestWritePumpExitsOnRemoteClose verifies the unregister path: when the +// remote client closes the connection, readPump exits, handleUnregister +// runs, closeDone() fires, and writePump exits via the c.done case. +func TestWritePumpExitsOnRemoteClose(t *testing.T) { + hub, dial := setupClientTest(t) + conn := dial() + _ = waitForClient(t, hub) + + if err := conn.Close(); err != nil { + t.Fatalf("conn.Close: %v", err) + } + + waitFor(t, time.Second, func() bool { return hub.ClientCount() == 0 }, + "client to be unregistered after remote close") +} + +// TestWritePumpExitsOnHubShutdown verifies that hub.Shutdown drives every +// active writePump to exit. With ctx.Done() removed from writePump, the +// path is: Run sees h.ctx.Done() → calls client.Close() on each → close(c.send) +// → writePump exits via the send case with ok=false. 
+func TestWritePumpExitsOnHubShutdown(t *testing.T) { + hub := NewHub() + go hub.Run() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hub.UpgradeConnection(w, r) + })) + t.Cleanup(server.Close) + + const n = 5 + conns := make([]*websocket.Conn, 0, n) + dialer := websocket.Dialer{} + url := "ws" + strings.TrimPrefix(server.URL, "http") + "/ws" + for range n { + c, _, err := dialer.Dial(url, nil) + if err != nil { + t.Fatalf("dial: %v", err) + } + t.Cleanup(func() { c.Close() }) + conns = append(conns, c) + } + waitFor(t, time.Second, func() bool { return hub.ClientCount() == n }, "all clients to register") + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + if err := hub.Shutdown(ctx); err != nil { + t.Fatalf("Shutdown: %v", err) + } + + // Every dialer-side conn should now observe an error on read — the writePump + // has sent a close frame (or the conn is dead) and the readPump has exited. + for i, c := range conns { + _ = c.SetReadDeadline(time.Now().Add(time.Second)) + if _, _, err := c.ReadMessage(); err == nil { + t.Errorf("conn %d: expected error after hub.Shutdown, got nil", i) + } + } +} + +// TestWritePumpDeliversBufferedSendsBeforeClose verifies that messages +// already queued in c.send are delivered before writePump exits when +// Close runs. The drainQueued path inside writePump is responsible +// for this; with ctx.Done() removed, a local close reaches the pump only as +// c.send yielding ok=false, which a receiver sees AFTER the buffered items. 
+func TestWritePumpDeliversBufferedSendsBeforeClose(t *testing.T) { + hub, dial := setupClientTest(t) + conn := dial() + + client := waitForClient(t, hub) + + const n = 5 + for i := range n { + if err := client.SendText(strings.Repeat("x", i+1)); err != nil { + t.Fatalf("Send %d: %v", i, err) + } + } + if err := client.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + + _ = conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + got := 0 + for got < n { + _, msg, err := conn.ReadMessage() + if err != nil { + break + } + want := strings.Repeat("x", got+1) + if string(msg) != want { + t.Errorf("frame %d = %q, want %q", got, msg, want) + } + got++ + } + if got != n { + t.Errorf("delivered %d frames before close, want %d", got, n) + } +} diff --git a/hub.go b/hub.go index 76db852..b5a0aaa 100644 --- a/hub.go +++ b/hub.go @@ -1049,7 +1049,7 @@ func (h *Hub) UpgradeConnection(w http.ResponseWriter, r *http.Request, opts ... h.wg.Add(2) go func() { defer h.wg.Done() - client.writePump(h.ctx) + client.writePump() }() go func() { defer h.wg.Done()