From d75630d8b499ca648672fe61f7eca69123e564eb Mon Sep 17 00:00:00 2001 From: GRACENOBLE Date: Mon, 15 Jun 2026 16:26:44 +0300 Subject: [PATCH 1/2] feat(websocket): add real-time WebSocket support via gorilla/websocket Implements issue #21 across all three layers: **Backend:** gorilla/websocket Hub with channel-based fan-out, GET /ws endpoint authenticated via Firebase token query param, hub wired as a long-lived goroutine in main.go with graceful shutdown. Handler struct now carries verifier and hub; RegisterRoutes signature simplified. **Web:** useWebSocket hook with typed WsEnvelope, exponential-backoff reconnection (capped at 30s), Firebase token as ?token= query param, and cleanup on unmount. Ref mutations moved to useLayoutEffect per React 19 react-hooks/refs rule. **Mobile:** OkHttp WebSocketManager with injectable factory and reconnect scheduler for testability, WebSocketViewModel exposing StateFlow, kotlinx.serialization for JSON parsing. Hub.Publish() left as the integration point for issue #18 (Asynq + Redis Streams). --- backend/cmd/api/main.go | 9 +- backend/docs/_index.md | 1 + backend/docs/routing.md | 25 ++- backend/docs/swagger/docs.go | 38 ++++ backend/docs/swagger/swagger.json | 38 ++++ backend/docs/swagger/swagger.yaml | 26 +++ backend/docs/websocket.md | 176 ++++++++++++++++++ backend/go.mod | 1 + backend/go.sum | 2 + backend/internal/infrastructure/ws/client.go | 96 ++++++++++ backend/internal/infrastructure/ws/hub.go | 72 +++++++ .../internal/infrastructure/ws/hub_test.go | 133 +++++++++++++ backend/internal/infrastructure/ws/message.go | 9 + backend/internal/server/server.go | 7 +- .../internal/transport/handlers/handler.go | 7 +- .../transport/handlers/health_handler_test.go | 4 +- backend/internal/transport/handlers/routes.go | 10 +- .../internal/transport/handlers/ws_handler.go | 51 +++++ mobile/app/build.gradle.kts | 6 + .../template/websocket/WebSocketManager.kt | 114 ++++++++++++ .../template/websocket/WebSocketViewModel.kt | 44 +++++ .../company/template/websocket/WsEnvelope.kt | 14 ++ .../company/template/websocket/FakeOkHttp.kt | 36 ++++ .../websocket/WebSocketManagerTest.kt | 140 ++++++++++++++ .../websocket/WebSocketViewModelTest.kt | 87 +++++++++ mobile/gradle/libs.versions.toml | 9 + web/docs/_index.md | 1 + web/docs/websocket.md | 118 ++++++++++++ web/lib/useWebSocket.test.ts | 149 +++++++++++++++ web/lib/useWebSocket.ts | 108 +++++++++++ 30 files changed, 1508 insertions(+), 23 deletions(-) create mode 100644 backend/docs/websocket.md create mode 100644 backend/internal/infrastructure/ws/client.go create mode 100644 backend/internal/infrastructure/ws/hub.go create mode 100644 backend/internal/infrastructure/ws/hub_test.go create mode 100644 backend/internal/infrastructure/ws/message.go create mode 100644 backend/internal/transport/handlers/ws_handler.go create mode 100644 mobile/app/src/main/java/com/company/template/websocket/WebSocketManager.kt create mode 100644 mobile/app/src/main/java/com/company/template/websocket/WebSocketViewModel.kt create mode 100644 mobile/app/src/main/java/com/company/template/websocket/WsEnvelope.kt create mode 100644 mobile/app/src/test/java/com/company/template/websocket/FakeOkHttp.kt create mode 100644 mobile/app/src/test/java/com/company/template/websocket/WebSocketManagerTest.kt create mode 100644 mobile/app/src/test/java/com/company/template/websocket/WebSocketViewModelTest.kt create mode 100644 web/docs/websocket.md create mode 100644 web/lib/useWebSocket.test.ts create mode 100644 web/lib/useWebSocket.ts diff --git a/backend/cmd/api/main.go b/backend/cmd/api/main.go index 66e9bcd..777c6fe 100644 --- a/backend/cmd/api/main.go +++ b/backend/cmd/api/main.go @@ -11,6 +11,7 @@ import ( "time" "backend/internal/bootstrap" + "backend/internal/infrastructure/ws" "backend/internal/server" ) @@ -55,7 +56,11 @@ func main() { os.Exit(1) } - srv := server.NewServer(app) + hubCtx, hubCancel := context.WithCancel(context.Background()) + hub := ws.NewHub() + go hub.Run(hubCtx) + + srv := server.NewServer(app, hub) slog.Info("API docs", "url", fmt.Sprintf("http://localhost%s/swagger/index.html", srv.Addr)) done := make(chan bool, 1) @@ -66,6 +71,8 @@ func main() { } <-done + hubCancel() // stop hub after all WS connections have been closed by server shutdown + if app.Cache != nil { if err := app.Cache.Close(); err != nil { slog.Error("cache close error", "error", err) diff --git a/backend/docs/_index.md b/backend/docs/_index.md index 6356f74..17b8974 100644 --- a/backend/docs/_index.md +++ b/backend/docs/_index.md @@ -15,3 +15,4 @@ The `docs` agent reads this index first to locate the right file before diving i | Middleware (logger, rate limiter) | [middleware.md](middleware.md) | `internal/transport/middleware/logger.go`, `internal/transport/middleware/ratelimit.go`, `internal/transport/handlers/routes.go` | | Firebase Auth (token verification, middleware, MeHandler) | [auth.md](auth.md) | `internal/usecase/auth_usecase.go`, `internal/transport/middleware/auth.go`, `internal/transport/handlers/auth_handler.go`, `pkg/firebase/admin.go`, `internal/bootstrap/bootstrap.go` | | Observability (Sentry error tracking) | [observability.md](observability.md) | `internal/transport/middleware/sentry.go`, `internal/bootstrap/bootstrap.go`, `internal/transport/handlers/routes.go` | +| WebSocket (Hub, client, GET /ws, auth, wiring) | [websocket.md](websocket.md) | `internal/infrastructure/ws/`, `internal/transport/handlers/ws_handler.go`, `internal/transport/handlers/routes.go`, `internal/server/server.go`, `cmd/api/main.go` | diff --git a/backend/docs/routing.md b/backend/docs/routing.md index dc4221d..b18fd20 100644 --- a/backend/docs/routing.md +++ b/backend/docs/routing.md @@ -16,16 +16,18 @@ sources: ## Handler struct ```go -// internal/handler/handler.go +// internal/transport/handlers/handler.go type Handler struct { healthUC usecase.HealthUseCase + verifier usecase.FirebaseTokenVerifier // nil disables auth (dev only) + hub *ws.Hub } -func NewHandler(healthUC usecase.HealthUseCase) *Handler { - return &Handler{healthUC: healthUC} +func NewHandler(healthUC usecase.HealthUseCase, verifier usecase.FirebaseTokenVerifier, hub *ws.Hub) *Handler { + return &Handler{healthUC: healthUC, verifier: verifier, hub: hub} } ``` -The `Handler` struct holds use case interfaces — not `*sql.DB` directly. Add new use case fields here as features are added. +The `Handler` struct holds use case interfaces and infrastructure dependencies — not `*sql.DB` directly. `verifier` is stored on the struct (not passed to `RegisterRoutes`) so the WebSocket handler can read it inline for query-param auth. ## Wiring (server.go) `internal/server/server.go` contains `NewServer(app *bootstrap.App) *http.Server` — wiring only, no logic. @@ -34,11 +36,11 @@ It receives the already-validated `*bootstrap.App` (which holds `*sql.DB`, `Cach ```go healthRepo := postgres.NewHealthRepository(app.DB) healthUC := usecase.NewHealthUseCase(healthRepo) -h := handlers.NewHandler(healthUC) +h := handlers.NewHandler(healthUC, app.Firebase, hub) return &http.Server{ Addr: fmt.Sprintf(":%d", app.Config.Port), - Handler: h.RegisterRoutes(app.Config.RateLimitRPS, app.Config.RateLimitBurst, app.Firebase, app.Config.SentryDSN), + Handler: h.RegisterRoutes(app.Config.RateLimitRPS, app.Config.RateLimitBurst, app.Config.SentryDSN), IdleTimeout: time.Minute, ReadTimeout: 10 * time.Second, WriteTimeout: 30 * time.Second, @@ -51,7 +53,7 @@ All routes registered in `RegisterRoutes()` on `*Handler`, which returns `http.H `verifier` is a `usecase.FirebaseTokenVerifier`; pass `nil` to skip Firebase auth (development only — see [auth](auth.md)). ```go -func (h *Handler) RegisterRoutes(rps float64, burst int, verifier usecase.FirebaseTokenVerifier, sentryDSN string) http.Handler { +func (h *Handler) RegisterRoutes(rps float64, burst int, sentryDSN string) http.Handler { r := gin.New() // Gin's colorful logger locally; structured slog logger in staging/production. @@ -67,11 +69,13 @@ func (h *Handler) RegisterRoutes(rps float64, burst int, verifier usecase.Fireba r.GET("/", h.HelloWorldHandler) r.GET("/health", h.HealthHandler) + r.GET("/ws", h.WsHandler) // WebSocket upgrade — auth via ?token= query param + r.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.Handler)) api := r.Group("/api/v1") - if verifier != nil { - api.Use(middleware.FirebaseAuth(verifier)) + if h.verifier != nil { + api.Use(middleware.FirebaseAuth(h.verifier)) } api.GET("/me", h.MeHandler) @@ -104,7 +108,8 @@ Allowed methods: GET, POST, PUT, DELETE, OPTIONS, PATCH. |---|---|---|---|---| | GET | `/` | none | `HelloWorldHandler` — returns `{"message": "Hello World"}` | `hello_handler.go` | | GET | `/health` | none | `HealthHandler` — returns `HealthStats`; 503 when DB is down | `health_handler.go` | -| GET | `/api/v1/me` | FirebaseAuth | `MeHandler` — returns verified `FirebaseToken` claims | `auth_handler.go` | +| GET | `/ws` | `?token=` query param | `WsHandler` — upgrades to WebSocket; 401 when token missing/invalid | `ws_handler.go` | +| GET | `/api/v1/me` | FirebaseAuth header | `MeHandler` — returns verified `FirebaseToken` claims | `auth_handler.go` | ## Graceful shutdown Wired in `cmd/api/main.go` via `signal.NotifyContext` for SIGINT/SIGTERM. diff --git a/backend/docs/swagger/docs.go b/backend/docs/swagger/docs.go index 8a74d01..2d1b049 100644 --- a/backend/docs/swagger/docs.go +++ b/backend/docs/swagger/docs.go @@ -97,6 +97,44 @@ const docTemplate = `{ } } } + }, + "/ws": { + "get": { + "description": "Upgrades HTTP to WebSocket. Pass a Firebase ID token as ` + "`" + `?token=\u003ctoken\u003e` + "`" + `. Returns 401 when the token is missing or invalid.", + "produces": [ + "application/json" + ], + "tags": [ + "websocket" + ], + "summary": "Open a WebSocket connection", + "parameters": [ + { + "type": "string", + "description": "Firebase ID token", + "name": "token", + "in": "query", + "required": true + } + ], + "responses": { + "101": { + "description": "Switching Protocols", + "schema": { + "type": "string" + } + }, + "401": { + "description": "Unauthorized", + "schema": { + "type": "object", + "additionalProperties": { + "type": "string" + } + } + } + } + } } }, "definitions": { diff --git a/backend/docs/swagger/swagger.json b/backend/docs/swagger/swagger.json index a9fef60..6af64f7 100644 --- a/backend/docs/swagger/swagger.json +++ b/backend/docs/swagger/swagger.json @@ -91,6 +91,44 @@ } } } + }, + "/ws": { + "get": { + "description": "Upgrades HTTP to WebSocket. Pass a Firebase ID token as `?token=\u003ctoken\u003e`. Returns 401 when the token is missing or invalid.", + "produces": [ + "application/json" + ], + "tags": [ + "websocket" + ], + "summary": "Open a WebSocket connection", + "parameters": [ + { + "type": "string", + "description": "Firebase ID token", + "name": "token", + "in": "query", + "required": true + } + ], + "responses": { + "101": { + "description": "Switching Protocols", + "schema": { + "type": "string" + } + }, + "401": { + "description": "Unauthorized", + "schema": { + "type": "object", + "additionalProperties": { + "type": "string" + } + } + } + } + } } }, "definitions": { diff --git a/backend/docs/swagger/swagger.yaml b/backend/docs/swagger/swagger.yaml index 1cd566c..8c5a43a 100644 --- a/backend/docs/swagger/swagger.yaml +++ b/backend/docs/swagger/swagger.yaml @@ -97,6 +97,32 @@ paths: summary: Health check tags: - ops + /ws: + get: + description: Upgrades HTTP to WebSocket. Pass a Firebase ID token as `?token=`. + Returns 401 when the token is missing or invalid. + parameters: + - description: Firebase ID token + in: query + name: token + required: true + type: string + produces: + - application/json + responses: + "101": + description: Switching Protocols + schema: + type: string + "401": + description: Unauthorized + schema: + additionalProperties: + type: string + type: object + summary: Open a WebSocket connection + tags: + - websocket securityDefinitions: BearerAuth: description: Firebase ID token — prefix with "Bearer " diff --git a/backend/docs/websocket.md b/backend/docs/websocket.md new file mode 100644 index 0000000..d2e8fb9 --- /dev/null +++ b/backend/docs/websocket.md @@ -0,0 +1,176 @@ +--- +topic: websocket +last_verified: 2026-06-15 +sources: + - internal/infrastructure/ws/message.go + - internal/infrastructure/ws/hub.go + - internal/infrastructure/ws/client.go + - internal/infrastructure/ws/hub_test.go + - internal/transport/handlers/ws_handler.go + - internal/transport/handlers/routes.go + - internal/server/server.go + - cmd/api/main.go +--- + +# WebSocket + +## Overview + +Real-time bidirectional communication is provided via `github.com/gorilla/websocket`. +A `Hub` runs as a long-lived goroutine and fans out messages to all connected clients. +The `GET /ws` endpoint upgrades HTTP connections; a Firebase ID token is required as a +query parameter. + +## Message envelope + +All messages use a typed JSON envelope defined in `internal/infrastructure/ws/message.go`: + +```go +type Envelope struct { + Type string `json:"type"` + Payload json.RawMessage `json:"payload"` +} +``` + +`Type` is a dot-separated event name (e.g. `"job.completed"`). `Payload` is arbitrary +JSON whose shape is determined by `Type`. + +## Hub + +`internal/infrastructure/ws/hub.go` + +```go +type Hub struct { + clients map[*Client]struct{} + broadcast chan []byte // buffered, capacity 256 + Register chan *Client + Unregister chan *Client +} + +func NewHub() *Hub +func (h *Hub) Run(ctx context.Context) // blocking; cancel ctx to stop +func (h *Hub) Publish(msgType string, payload any) error +``` + +`Run` must be called in its own goroutine and runs until `ctx` is cancelled. +`Publish` marshals `payload` into an `Envelope` and queues it for broadcast — +safe to call from any goroutine (e.g. an Asynq worker in future #18). + +### Goroutine model + +Each WebSocket connection spawns two goroutines: `ReadPump` and `WritePump` (on `Client`). +The Hub serialises all mutations (register / unregister / broadcast) through a `select` loop +so no locking is needed on its internal `clients` map. + +``` + caller goroutine + │ hub.Publish(...) + ▼ + hub.broadcast chan + │ + hub.Run goroutine ──► client.Send chan ──► client.WritePump goroutine ──► WebSocket conn + client.ReadPump goroutine ──► (discards incoming / handles pings) +``` + +Slow clients are dropped: if `client.Send` is full, the Hub closes the channel and +removes the client without blocking the broadcast loop. + +## Client + +`internal/infrastructure/ws/client.go` + +```go +type Client struct { + hub *Hub + conn *websocket.Conn + Send chan []byte // exported for testing +} + +func NewClient(hub *Hub, conn *websocket.Conn) *Client // registers with hub +func (c *Client) ReadPump() // must run in goroutine +func (c *Client) WritePump() // must run in goroutine +``` + +Ping/pong keepalive: `pingPeriod = 54s`, `pongWait = 60s`, `writeWait = 10s`. + +## Route — GET /ws + +``` +GET /ws?token= +``` + +Defined in `internal/transport/handlers/ws_handler.go`. Registered in `RegisterRoutes` +outside the `/api/v1` auth group — auth is handled inline because WebSocket clients +cannot set `Authorization` headers. + +**Auth flow:** +1. If `h.verifier != nil` (staging / production): reads `?token=` query param. + Returns `401` when missing or when `VerifyIDToken` fails. +2. If `h.verifier == nil` (development): skips auth — connects immediately. + +After successful auth, the connection is upgraded and `ReadPump` / `WritePump` are +started in separate goroutines. + +## Wiring in server.go and main.go + +`server.go` accepts `*ws.Hub` as a second argument: + +```go +func NewServer(app *bootstrap.App, hub *ws.Hub) *http.Server +``` + +`cmd/api/main.go` creates the Hub, starts `Run` with a child context, and cancels +it after the HTTP server shuts down (so all in-flight connections close first): + +```go +hubCtx, hubCancel := context.WithCancel(context.Background()) +hub := ws.NewHub() +go hub.Run(hubCtx) + +srv := server.NewServer(app, hub) +// ... +<-done +hubCancel() // stop hub after server drains connections +``` + +## Handler struct + +`verifier` (for WS auth) and `hub` are now fields on `Handler`: + +```go +type Handler struct { + healthUC usecase.HealthUseCase + verifier usecase.FirebaseTokenVerifier // nil disables auth (dev only) + hub *ws.Hub +} + +func NewHandler(healthUC usecase.HealthUseCase, verifier usecase.FirebaseTokenVerifier, hub *ws.Hub) *Handler +``` + +`RegisterRoutes` no longer accepts `verifier` as a parameter — it reads from `h.verifier`. + +## Publishing events from workers (future #18) + +Call `hub.Publish` from any goroutine: + +```go +hub.Publish("job.completed", map[string]any{ + "jobId": id, + "status": "done", +}) +``` + +When #18 (Asynq + Redis Streams) lands, the Asynq task handlers and Redis Streams +consumers will call this method to push domain events to connected clients. + +## Testing + +Unit tests in `internal/infrastructure/ws/hub_test.go` cover: +- `TestHub_RegisterAndBroadcast` — single client receives broadcast +- `TestHub_UnregisterRemovesClient` — Send channel is closed on unregister +- `TestHub_ContextCancelClosesSendChannels` — all channels closed on ctx cancel +- `TestHub_ConcurrentClientsAndBroadcast` — 10 concurrent clients all receive +- `TestHub_Publish` — JSON marshalling and delivery + +Tests inject `*Client` with a nil `conn` and a buffered `Send` channel; the Hub only +touches `client.Send`, not the connection, so real WebSocket connections are not needed. diff --git a/backend/go.mod b/backend/go.mod index 3522f03..b5a6056 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -101,6 +101,7 @@ require ( github.com/google/uuid v1.6.0 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.16 // indirect github.com/googleapis/gax-go/v2 v2.22.0 // indirect + github.com/gorilla/websocket v1.5.3 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect diff --git a/backend/go.sum b/backend/go.sum index 87494dd..95314cd 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -205,6 +205,8 @@ github.com/googleapis/enterprise-certificate-proxy v0.3.16 h1:F/VPrx0YPBdksZJQdC github.com/googleapis/enterprise-certificate-proxy v0.3.16/go.mod h1:9Yb0eAkH/Xqhvv3zbeKf/+wMJqCeocWc6KIhDvEAuYE= github.com/googleapis/gax-go/v2 v2.22.0 h1:PjIWBpgGIVKGoCXuiCoP64altEJCj3/Ei+kSU5vlZD4= github.com/googleapis/gax-go/v2 v2.22.0/go.mod h1:irWBbALSr0Sk3qlqb9SyJ1h68WjgeFuiOzI4Rqw5+aY= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= diff --git a/backend/internal/infrastructure/ws/client.go b/backend/internal/infrastructure/ws/client.go new file mode 100644 index 0000000..e39effa --- /dev/null +++ b/backend/internal/infrastructure/ws/client.go @@ -0,0 +1,96 @@ +package ws + +import ( + "log/slog" + "time" + + "github.com/gorilla/websocket" +) + +const ( + writeWait = 10 * time.Second + pongWait = 60 * time.Second + pingPeriod = (pongWait * 9) / 10 + maxMessageSize = 512 +) + +// Client couples a single WebSocket connection to the Hub. +// Send is exported so hub_test.go can inject a client without a real conn. +type Client struct { + hub *Hub + conn *websocket.Conn + Send chan []byte +} + +// NewClient allocates a Client and registers it with the hub. +// Call WritePump and ReadPump in separate goroutines after this returns. +func NewClient(hub *Hub, conn *websocket.Conn) *Client { + c := &Client{ + hub: hub, + conn: conn, + Send: make(chan []byte, 256), + } + hub.Register <- c + return c +} + +// ReadPump reads from the WebSocket and handles connection liveness. +// It unregisters the client and closes the connection when done. +func (c *Client) ReadPump() { + defer func() { + c.hub.Unregister <- c + c.conn.Close() + }() + c.conn.SetReadLimit(maxMessageSize) + c.conn.SetReadDeadline(time.Now().Add(pongWait)) //nolint:errcheck + c.conn.SetPongHandler(func(string) error { + return c.conn.SetReadDeadline(time.Now().Add(pongWait)) + }) + for { + _, _, err := c.conn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { + slog.Error("websocket read error", "error", err) + } + break + } + // Server-to-client only for now; incoming messages are discarded. + } +} + +// WritePump writes queued messages to the WebSocket and sends periodic pings. +func (c *Client) WritePump() { + ticker := time.NewTicker(pingPeriod) + defer func() { + ticker.Stop() + c.conn.Close() + }() + for { + select { + case msg, ok := <-c.Send: + c.conn.SetWriteDeadline(time.Now().Add(writeWait)) //nolint:errcheck + if !ok { + c.conn.WriteMessage(websocket.CloseMessage, []byte{}) //nolint:errcheck + return + } + w, err := c.conn.NextWriter(websocket.TextMessage) + if err != nil { + return + } + w.Write(msg) //nolint:errcheck + // Drain any queued messages into the same WebSocket frame. + for i := len(c.Send); i > 0; i-- { + w.Write([]byte{'\n'}) //nolint:errcheck + w.Write(<-c.Send) //nolint:errcheck + } + if err := w.Close(); err != nil { + return + } + case <-ticker.C: + c.conn.SetWriteDeadline(time.Now().Add(writeWait)) //nolint:errcheck + if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil { + return + } + } + } +} diff --git a/backend/internal/infrastructure/ws/hub.go b/backend/internal/infrastructure/ws/hub.go new file mode 100644 index 0000000..0497294 --- /dev/null +++ b/backend/internal/infrastructure/ws/hub.go @@ -0,0 +1,72 @@ +package ws + +import ( + "context" + "encoding/json" +) + +// Hub maintains the set of active WebSocket clients and broadcasts messages to them. +// All mutations are serialised through the Run goroutine — no locking needed on the +// clients map itself. +type Hub struct { + clients map[*Client]struct{} + broadcast chan []byte + Register chan *Client + Unregister chan *Client +} + +// NewHub allocates a Hub with buffered channels. +func NewHub() *Hub { + return &Hub{ + clients: make(map[*Client]struct{}), + broadcast: make(chan []byte, 256), + Register: make(chan *Client), + Unregister: make(chan *Client), + } +} + +// Run processes register, unregister, and broadcast events until ctx is cancelled. +// Call this in its own goroutine. +func (h *Hub) Run(ctx context.Context) { + for { + select { + case <-ctx.Done(): + for c := range h.clients { + close(c.Send) + } + return + case c := <-h.Register: + h.clients[c] = struct{}{} + case c := <-h.Unregister: + if _, ok := h.clients[c]; ok { + delete(h.clients, c) + close(c.Send) + } + case msg := <-h.broadcast: + for c := range h.clients { + select { + case c.Send <- msg: + default: + // Slow client: drop and disconnect. + close(c.Send) + delete(h.clients, c) + } + } + } + } +} + +// Publish marshals msgType + payload into an Envelope and queues it for broadcast. +// Safe to call from any goroutine (e.g. an Asynq worker or Redis Streams consumer). +func (h *Hub) Publish(msgType string, payload any) error { + raw, err := json.Marshal(payload) + if err != nil { + return err + } + env, err := json.Marshal(Envelope{Type: msgType, Payload: raw}) + if err != nil { + return err + } + h.broadcast <- env + return nil +} diff --git a/backend/internal/infrastructure/ws/hub_test.go b/backend/internal/infrastructure/ws/hub_test.go new file mode 100644 index 0000000..ee017be --- /dev/null +++ b/backend/internal/infrastructure/ws/hub_test.go @@ -0,0 +1,133 @@ +package ws + +import ( + "context" + "sync" + "testing" + "time" +) + +func newTestClient(hub *Hub) *Client { + return &Client{hub: hub, Send: make(chan []byte, 16)} +} + +func TestHub_RegisterAndBroadcast(t *testing.T) { + hub := NewHub() + go hub.Run(t.Context()) + + c := newTestClient(hub) + hub.Register <- c + time.Sleep(5 * time.Millisecond) + + msg := []byte(`{"type":"ping","payload":null}`) + hub.broadcast <- msg + + select { + case got := <-c.Send: + if string(got) != string(msg) { + t.Errorf("got %q, want %q", got, msg) + } + case <-time.After(time.Second): + t.Fatal("timed out waiting for broadcast") + } +} + +func TestHub_UnregisterRemovesClient(t *testing.T) { + hub := NewHub() + go hub.Run(t.Context()) + + c := newTestClient(hub) + hub.Register <- c + time.Sleep(5 * time.Millisecond) + + hub.Unregister <- c + time.Sleep(5 * time.Millisecond) + + select { + case _, ok := <-c.Send: + if ok { + t.Error("expected Send channel to be closed") + } + default: + t.Error("expected Send channel to be closed but it was still open") + } +} + +func TestHub_ContextCancelClosesSendChannels(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + hub := NewHub() + go hub.Run(ctx) + + c := newTestClient(hub) + hub.Register <- c + time.Sleep(5 * time.Millisecond) + + cancel() + time.Sleep(10 * time.Millisecond) + + select { + case _, ok := <-c.Send: + if ok { + t.Error("expected Send channel to be closed after context cancel") + } + default: + t.Error("expected Send channel to be closed but it was still open") + } +} + +func TestHub_ConcurrentClientsAndBroadcast(t *testing.T) { + hub := NewHub() + go hub.Run(t.Context()) + + const n = 10 + clients := make([]*Client, n) + for i := range clients { + clients[i] = newTestClient(hub) + hub.Register <- clients[i] + } + time.Sleep(10 * time.Millisecond) + + msg := []byte(`{"type":"hello","payload":null}`) + + var wg sync.WaitGroup + for _, c := range clients { + wg.Add(1) + go func(cl *Client) { + defer wg.Done() + select { + case got := <-cl.Send: + if string(got) != string(msg) { + t.Errorf("client got %q, want %q", got, msg) + } + case <-time.After(time.Second): + t.Error("client timed out waiting for broadcast") + } + }(c) + } + + hub.broadcast <- msg + wg.Wait() +} + +func TestHub_Publish(t *testing.T) { + hub := NewHub() + go hub.Run(t.Context()) + + c := newTestClient(hub) + hub.Register <- c + time.Sleep(5 * time.Millisecond) + + if err := hub.Publish("job.completed", map[string]string{"id": "123"}); err != nil { + t.Fatalf("Publish() error: %v", err) + } + + select { + case got := <-c.Send: + if len(got) == 0 { + t.Error("expected non-empty message from Publish") + } + case <-time.After(time.Second): + t.Fatal("timed out waiting for published message") + } +} diff --git a/backend/internal/infrastructure/ws/message.go b/backend/internal/infrastructure/ws/message.go new file mode 100644 index 0000000..0916ed7 --- /dev/null +++ b/backend/internal/infrastructure/ws/message.go @@ -0,0 +1,9 @@ +package ws + +import "encoding/json" + +// Envelope is the typed wire format for all WebSocket messages. +type Envelope struct { + Type string `json:"type"` + Payload json.RawMessage `json:"payload"` +} diff --git a/backend/internal/server/server.go b/backend/internal/server/server.go index 40774ba..68a53c0 100644 --- a/backend/internal/server/server.go +++ b/backend/internal/server/server.go @@ -9,12 +9,13 @@ import ( "backend/internal/bootstrap" "backend/internal/infrastructure/database/postgres" + "backend/internal/infrastructure/ws" "backend/internal/transport/handlers" "backend/internal/usecase" ) // NewServer wires all layers and returns a configured *http.Server. -func NewServer(app *bootstrap.App) *http.Server { +func NewServer(app *bootstrap.App, hub *ws.Hub) *http.Server { switch app.Config.Env { case "staging", "production": gin.SetMode(gin.ReleaseMode) @@ -28,11 +29,11 @@ func NewServer(app *bootstrap.App) *http.Server { healthRepo := postgres.NewHealthRepository(app.DB) healthUC := usecase.NewHealthUseCase(healthRepo) - h := handlers.NewHandler(healthUC) + h := handlers.NewHandler(healthUC, app.Firebase, hub) return &http.Server{ Addr: fmt.Sprintf(":%d", app.Config.Port), - Handler: h.RegisterRoutes(app.Config.RateLimitRPS, app.Config.RateLimitBurst, app.Firebase, app.Config.SentryDSN), + Handler: h.RegisterRoutes(app.Config.RateLimitRPS, app.Config.RateLimitBurst, app.Config.SentryDSN), IdleTimeout: time.Minute, ReadTimeout: 10 * time.Second, WriteTimeout: 30 * time.Second, diff --git a/backend/internal/transport/handlers/handler.go b/backend/internal/transport/handlers/handler.go index e738fb6..02f133c 100644 --- a/backend/internal/transport/handlers/handler.go +++ b/backend/internal/transport/handlers/handler.go @@ -1,15 +1,18 @@ package handlers import ( + "backend/internal/infrastructure/ws" "backend/internal/usecase" ) // Handler holds all use case dependencies for HTTP handlers. type Handler struct { healthUC usecase.HealthUseCase + verifier usecase.FirebaseTokenVerifier // nil disables auth (dev only) + hub *ws.Hub } // NewHandler constructs a Handler with all required use cases. -func NewHandler(healthUC usecase.HealthUseCase) *Handler { - return &Handler{healthUC: healthUC} +func NewHandler(healthUC usecase.HealthUseCase, verifier usecase.FirebaseTokenVerifier, hub *ws.Hub) *Handler { + return &Handler{healthUC: healthUC, verifier: verifier, hub: hub} } diff --git a/backend/internal/transport/handlers/health_handler_test.go b/backend/internal/transport/handlers/health_handler_test.go index 2669911..566edbb 100644 --- a/backend/internal/transport/handlers/health_handler_test.go +++ b/backend/internal/transport/handlers/health_handler_test.go @@ -31,7 +31,7 @@ func TestHealthHandler_Success(t *testing.T) { Status: "up", Message: "It's healthy", } - h := NewHandler(&mockHealthUC{stats: want}) + h := NewHandler(&mockHealthUC{stats: want}, nil, nil) r := gin.New() r.GET("/health", h.HealthHandler) @@ -50,7 +50,7 @@ func TestHealthHandler_Success(t *testing.T) { } func TestHealthHandler_ServiceUnavailable(t *testing.T) { - h := NewHandler(&mockHealthUC{err: errors.New("connection refused")}) + h := NewHandler(&mockHealthUC{err: errors.New("connection refused")}, nil, nil) r := gin.New() r.GET("/health", h.HealthHandler) diff --git a/backend/internal/transport/handlers/routes.go b/backend/internal/transport/handlers/routes.go index 9d9ef8e..efe88d4 100644 --- a/backend/internal/transport/handlers/routes.go +++ b/backend/internal/transport/handlers/routes.go @@ -10,14 +10,13 @@ import ( _ "backend/docs/swagger" "backend/internal/transport/middleware" - "backend/internal/usecase" ) // RegisterRoutes creates the Gin engine, applies middleware, and registers all routes. // rps and burst configure IP-based rate limiting; pass rps<=0 to disable. -// verifier enables Firebase token auth on protected routes; pass nil to skip auth (dev only). // sentryDSN enables Sentry error tracking; pass empty string to disable. -func (h *Handler) RegisterRoutes(rps float64, burst int, verifier usecase.FirebaseTokenVerifier, sentryDSN string) http.Handler { +// Firebase auth is read from h.verifier; nil disables auth (dev only). +func (h *Handler) RegisterRoutes(rps float64, burst int, sentryDSN string) http.Handler { r := gin.New() r.Use(middleware.SentryMiddleware(sentryDSN)) @@ -40,12 +39,13 @@ func (h *Handler) RegisterRoutes(rps float64, burst int, verifier usecase.Fireba r.GET("/", h.HelloWorldHandler) r.GET("/health", h.HealthHandler) + r.GET("/ws", h.WsHandler) r.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.Handler)) api := r.Group("/api/v1") - if verifier != nil { - api.Use(middleware.FirebaseAuth(verifier)) + if h.verifier != nil { + api.Use(middleware.FirebaseAuth(h.verifier)) } api.GET("/me", h.MeHandler) diff --git a/backend/internal/transport/handlers/ws_handler.go b/backend/internal/transport/handlers/ws_handler.go new file mode 100644 index 0000000..66aa030 --- /dev/null +++ b/backend/internal/transport/handlers/ws_handler.go @@ -0,0 +1,51 @@ +package handlers + +import ( + "net/http" + + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" + + ws "backend/internal/infrastructure/ws" +) + +var upgrader = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + // TODO: restrict to known origins in production. + CheckOrigin: func(r *http.Request) bool { return true }, +} + +// WsHandler godoc +// +// @Summary Open a WebSocket connection +// @Description Upgrades HTTP to WebSocket. Pass a Firebase ID token as `?token=`. Returns 401 when the token is missing or invalid. +// @Tags websocket +// @Produce json +// @Param token query string true "Firebase ID token" +// @Success 101 {string} string "Switching Protocols" +// @Failure 401 {object} map[string]string +// @Router /ws [get] +func (h *Handler) WsHandler(c *gin.Context) { + if h.verifier != nil { + token := c.Query("token") + if token == "" { + c.JSON(http.StatusUnauthorized, gin.H{"error": "missing token"}) + return + } + if _, err := h.verifier.VerifyIDToken(c.Request.Context(), token); err != nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid or expired token"}) + return + } + } + + conn, err := upgrader.Upgrade(c.Writer, c.Request, nil) + if err != nil { + // Upgrader already wrote the error response. + return + } + + client := ws.NewClient(h.hub, conn) + go client.WritePump() + go client.ReadPump() +} diff --git a/mobile/app/build.gradle.kts b/mobile/app/build.gradle.kts index 18ff9b9..55aab6f 100644 --- a/mobile/app/build.gradle.kts +++ b/mobile/app/build.gradle.kts @@ -3,6 +3,7 @@ import java.util.Properties plugins { alias(libs.plugins.android.application) alias(libs.plugins.kotlin.compose) + alias(libs.plugins.kotlin.serialization) } val localProps = Properties().apply { @@ -57,7 +58,12 @@ dependencies { implementation(libs.androidx.core.ktx) implementation(libs.androidx.lifecycle.runtime.ktx) implementation(libs.sentry.android) + implementation(libs.okhttp) + implementation(libs.kotlinx.coroutines.android) + implementation(libs.kotlinx.serialization.json) + implementation(libs.androidx.lifecycle.viewmodel.ktx) testImplementation(libs.junit) + testImplementation(libs.kotlinx.coroutines.test) androidTestImplementation(platform(libs.androidx.compose.bom)) androidTestImplementation(libs.androidx.compose.ui.test.junit4) androidTestImplementation(libs.androidx.espresso.core) diff --git a/mobile/app/src/main/java/com/company/template/websocket/WebSocketManager.kt b/mobile/app/src/main/java/com/company/template/websocket/WebSocketManager.kt new file mode 100644 index 0000000..02c410f --- /dev/null +++ b/mobile/app/src/main/java/com/company/template/websocket/WebSocketManager.kt @@ -0,0 +1,114 @@ +package com.company.template.websocket + +import kotlinx.serialization.encodeToString +import kotlinx.serialization.json.Json +import okhttp3.OkHttpClient +import okhttp3.Request +import okhttp3.Response +import okhttp3.WebSocket +import okhttp3.WebSocketListener + +/** + * Factory abstraction over [OkHttpClient.newWebSocket] — injectable for unit tests. + */ +fun interface WebSocketFactory { + fun newWebSocket(request: Request, listener: WebSocketListener): WebSocket +} + +/** + * Manages a single OkHttp WebSocket connection with exponential-backoff reconnection. + * + * Callbacks fire on OkHttp's dispatcher thread. Callers that update UI state must + * marshal to the main thread (e.g. via StateFlow or viewModelScope). + * + * [reconnectScheduler] is injectable so unit tests can drive retries synchronously. + * The default implementation uses [android.os.Handler] on the main looper. + */ +class WebSocketManager( + private val serverUrl: String, + private val factory: WebSocketFactory = defaultFactory(), + val maxRetries: Int = 10, + private val reconnectScheduler: (delayMs: Long, action: () -> Unit) -> Unit = { delay, action -> + android.os.Handler(android.os.Looper.getMainLooper()).postDelayed(action, delay) + }, +) { + var onOpen: (() -> Unit)? = null + var onClose: (() -> Unit)? = null + var onMessage: ((WsEnvelope) -> Unit)? = null + var onError: ((Throwable) -> Unit)? = null + + private var socket: WebSocket? = null + private var retryCount = 0 + private var currentToken: String? = null + var active = false + private set + + fun connect(token: String?) { + active = true + currentToken = token + retryCount = 0 + openSocket() + } + + fun disconnect() { + active = false + socket?.close(1000, "client disconnect") + socket = null + } + + fun send(envelope: WsEnvelope): Boolean { + val json = Json.encodeToString(envelope) + return socket?.send(json) ?: false + } + + private fun openSocket() { + val urlBuilder = StringBuilder(serverUrl) + currentToken?.let { urlBuilder.append("?token=").append(it) } + + val request = Request.Builder() + .url(urlBuilder.toString()) + .build() + + socket = factory.newWebSocket(request, listener) + } + + private val listener = object : WebSocketListener() { + override fun onOpen(webSocket: WebSocket, response: Response) { + retryCount = 0 + onOpen?.invoke() + } + + override fun onMessage(webSocket: WebSocket, text: String) { + parseEnvelope(text)?.let { this@WebSocketManager.onMessage?.invoke(it) } + } + + override fun onClosed(webSocket: WebSocket, code: Int, reason: String) { + onClose?.invoke() + } + + override fun onFailure(webSocket: WebSocket, t: Throwable, response: Response?) { + onError?.invoke(t) + if (active && retryCount < maxRetries) { + val delay = minOf(1000L * (1L shl retryCount), 30_000L) + retryCount++ + reconnectScheduler(delay) { if (active) openSocket() } + } else { + onClose?.invoke() + } + } + } + + private fun parseEnvelope(text: String): WsEnvelope? = + try { + Json.decodeFromString(text) + } catch (_: Exception) { + null + } + + companion object { + private fun defaultFactory(): WebSocketFactory { + val client = OkHttpClient() + return WebSocketFactory { request, listener -> client.newWebSocket(request, listener) } + } + } +} diff --git a/mobile/app/src/main/java/com/company/template/websocket/WebSocketViewModel.kt b/mobile/app/src/main/java/com/company/template/websocket/WebSocketViewModel.kt new file mode 100644 index 0000000..7f55485 --- /dev/null +++ b/mobile/app/src/main/java/com/company/template/websocket/WebSocketViewModel.kt @@ -0,0 +1,44 @@ +package com.company.template.websocket + +import androidx.lifecycle.ViewModel +import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.StateFlow +import kotlinx.coroutines.flow.asStateFlow +import kotlinx.coroutines.flow.update + +data class WsState( + val isConnected: Boolean = false, + val lastMessage: WsEnvelope? = null, +) + +class WebSocketViewModel( + private val manager: WebSocketManager, +) : ViewModel() { + + private val _state = MutableStateFlow(WsState()) + val state: StateFlow = _state.asStateFlow() + + init { + manager.onOpen = { _state.update { it.copy(isConnected = true) } } + manager.onClose = { _state.update { it.copy(isConnected = false) } } + manager.onMessage = { envelope -> + _state.update { it.copy(lastMessage = envelope) } + } + } + + fun connect(token: String? = null) { + manager.connect(token) + } + + fun disconnect() { + manager.disconnect() + } + + fun send(envelope: WsEnvelope) { + manager.send(envelope) + } + + public override fun onCleared() { + manager.disconnect() + } +} diff --git a/mobile/app/src/main/java/com/company/template/websocket/WsEnvelope.kt b/mobile/app/src/main/java/com/company/template/websocket/WsEnvelope.kt new file mode 100644 index 0000000..aaec10c --- /dev/null +++ b/mobile/app/src/main/java/com/company/template/websocket/WsEnvelope.kt @@ -0,0 +1,14 @@ +package com.company.template.websocket + +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.JsonElement + +/** + * Typed wire format for all WebSocket messages, mirroring the backend Envelope struct. + * [payload] is a raw [JsonElement] because message types vary; callers decode based on [type]. + */ +@Serializable +data class WsEnvelope( + val type: String, + val payload: JsonElement? = null, +) diff --git a/mobile/app/src/test/java/com/company/template/websocket/FakeOkHttp.kt b/mobile/app/src/test/java/com/company/template/websocket/FakeOkHttp.kt new file mode 100644 index 0000000..0f22268 --- /dev/null +++ b/mobile/app/src/test/java/com/company/template/websocket/FakeOkHttp.kt @@ -0,0 +1,36 @@ +package com.company.template.websocket + +import okhttp3.Request +import okhttp3.WebSocket +import okhttp3.WebSocketListener +import okio.ByteString + +/** + * Test double for [WebSocketFactory]. Captures listener and created sockets so tests + * can drive WebSocket callbacks without real network I/O. + */ +class FakeWebSocketFactory : WebSocketFactory { + val sockets = mutableListOf() + var lastListener: WebSocketListener? = null + + val lastSocket: FakeWebSocket get() = sockets.last() + + override fun newWebSocket(request: Request, listener: WebSocketListener): WebSocket { + lastListener = listener + return FakeWebSocket(request).also { sockets.add(it) } + } +} + +class FakeWebSocket(private val req: Request = Request.Builder().url("ws://localhost/ws").build()) : + WebSocket { + + val sentMessages = mutableListOf() + var closed = false + + override fun request(): Request = req + override fun queueSize(): Long = 0L + override fun send(text: String): Boolean { sentMessages.add(text); return true } + override fun send(bytes: ByteString): Boolean = false + override fun close(code: Int, reason: String?): Boolean { closed = true; return true } + override fun cancel() { closed = true } +} diff --git a/mobile/app/src/test/java/com/company/template/websocket/WebSocketManagerTest.kt b/mobile/app/src/test/java/com/company/template/websocket/WebSocketManagerTest.kt new file mode 100644 index 0000000..463cc18 --- /dev/null +++ b/mobile/app/src/test/java/com/company/template/websocket/WebSocketManagerTest.kt @@ -0,0 +1,140 @@ +package com.company.template.websocket + +import okhttp3.Protocol +import okhttp3.Request +import okhttp3.Response +import org.junit.Assert.assertEquals +import org.junit.Assert.assertFalse +import org.junit.Assert.assertNull +import org.junit.Assert.assertTrue +import org.junit.Before +import org.junit.Test + +class WebSocketManagerTest { + + private lateinit var factory: FakeWebSocketFactory + private val scheduledActions = mutableListOf<() -> Unit>() + + private fun makeManager(maxRetries: Int = 3): WebSocketManager = + WebSocketManager( + serverUrl = "ws://localhost:8080/ws", + factory = factory, + maxRetries = maxRetries, + reconnectScheduler = { _, action -> scheduledActions.add(action) }, + ) + + private fun fakeResponse(): Response = + Response.Builder() + .request(Request.Builder().url("ws://localhost:8080/ws").build()) + .protocol(Protocol.HTTP_1_1) + .code(101) + .message("Switching Protocols") + .build() + + @Before + fun setUp() { + factory = FakeWebSocketFactory() + scheduledActions.clear() + } + + @Test + fun `onOpen callback is invoked on successful connection`() { + var opened = false + val manager = makeManager() + manager.onOpen = { opened = true } + + manager.connect(token = null) + factory.lastListener!!.onOpen(factory.lastSocket, fakeResponse()) + + assertTrue(opened) + } + + @Test + fun `onMessage delivers parsed envelope`() { + var received: WsEnvelope? = null + val manager = makeManager() + manager.onMessage = { received = it } + + manager.connect(token = null) + factory.lastListener!!.onMessage( + factory.lastSocket, + """{"type":"job.completed","payload":{"id":"42"}}""", + ) + + assertEquals("job.completed", received?.type) + } + + @Test + fun `onMessage ignores malformed JSON without crashing`() { + var called = false + val manager = makeManager() + manager.onMessage = { called = true } + + manager.connect(token = null) + factory.lastListener!!.onMessage(factory.lastSocket, "not-json") + + assertFalse(called) + } + + @Test + fun `onMessage returns null for missing type field`() { + var received: WsEnvelope? = null + val manager = makeManager() + manager.onMessage = { received = it } + + manager.connect(token = null) + factory.lastListener!!.onMessage(factory.lastSocket, """{"payload":42}""") + + assertNull(received) + } + + @Test + fun `onClose callback fires on graceful close`() { + var closed = false + val manager = makeManager() + manager.onClose = { closed = true } + + manager.connect(token = null) + factory.lastListener!!.onClosed(factory.lastSocket, 1000, "normal") + + assertTrue(closed) + } + + @Test + fun `onFailure schedules retry when retries remain`() { + val manager = makeManager(maxRetries = 2) + manager.connect(token = null) + + factory.lastListener!!.onFailure(factory.lastSocket, RuntimeException("refused"), null) + + assertEquals(1, scheduledActions.size) + } + + @Test + fun `onFailure does not retry after maxRetries exceeded`() { + val manager = makeManager(maxRetries = 1) + manager.connect(token = null) + + factory.lastListener!!.onFailure(factory.lastSocket, RuntimeException("fail"), null) + scheduledActions.last().invoke() + factory.lastListener!!.onFailure(factory.lastSocket, RuntimeException("fail"), null) + + // First failure schedules 1 retry; second failure finds retryCount >= maxRetries. + assertEquals(1, scheduledActions.size) + } + + @Test + fun `disconnect prevents scheduled reconnect from opening new socket`() { + val manager = makeManager(maxRetries = 3) + manager.connect(token = null) + + factory.lastListener!!.onFailure(factory.lastSocket, RuntimeException("fail"), null) + assertEquals(1, scheduledActions.size) + + manager.disconnect() + scheduledActions.first().invoke() + + // After disconnect + action fires, no new socket should be opened. + assertEquals(1, factory.sockets.size) + } +} diff --git a/mobile/app/src/test/java/com/company/template/websocket/WebSocketViewModelTest.kt b/mobile/app/src/test/java/com/company/template/websocket/WebSocketViewModelTest.kt new file mode 100644 index 0000000..ed683c8 --- /dev/null +++ b/mobile/app/src/test/java/com/company/template/websocket/WebSocketViewModelTest.kt @@ -0,0 +1,87 @@ +package com.company.template.websocket + +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.test.UnconfinedTestDispatcher +import kotlinx.coroutines.test.resetMain +import kotlinx.coroutines.test.setMain +import org.junit.After +import org.junit.Assert.assertEquals +import org.junit.Assert.assertFalse +import org.junit.Assert.assertNull +import org.junit.Assert.assertTrue +import org.junit.Before +import org.junit.Test + +@OptIn(ExperimentalCoroutinesApi::class) +class WebSocketViewModelTest { + + private lateinit var factory: FakeWebSocketFactory + private lateinit var manager: WebSocketManager + private lateinit var viewModel: WebSocketViewModel + + @Before + fun setUp() { + Dispatchers.setMain(UnconfinedTestDispatcher()) + factory = FakeWebSocketFactory() + manager = WebSocketManager( + serverUrl = "ws://localhost:8080/ws", + factory = factory, + reconnectScheduler = { _, _ -> }, + ) + viewModel = WebSocketViewModel(manager) + } + + @After + fun tearDown() { + Dispatchers.resetMain() + } + + @Test + fun `initial state has isConnected false and no lastMessage`() { + assertFalse(viewModel.state.value.isConnected) + assertNull(viewModel.state.value.lastMessage) + } + + @Test + fun `connect then open sets isConnected true`() { + viewModel.connect() + manager.onOpen!!.invoke() + + assertTrue(viewModel.state.value.isConnected) + } + + @Test + fun `close event sets isConnected false`() { + viewModel.connect() + manager.onOpen!!.invoke() + manager.onClose!!.invoke() + + assertFalse(viewModel.state.value.isConnected) + } + + @Test + fun `incoming message updates lastMessage`() { + val envelope = WsEnvelope(type = "ping") + viewModel.connect() + manager.onMessage!!.invoke(envelope) + + assertEquals(envelope, viewModel.state.value.lastMessage) + } + + @Test + fun `disconnect calls manager disconnect`() { + viewModel.connect() + viewModel.disconnect() + + assertFalse(manager.active) + } + + @Test + fun `onCleared calls manager disconnect`() { + viewModel.connect() + viewModel.onCleared() + + assertFalse(manager.active) + } +} diff --git a/mobile/gradle/libs.versions.toml b/mobile/gradle/libs.versions.toml index 84bb6ef..2364ad3 100644 --- a/mobile/gradle/libs.versions.toml +++ b/mobile/gradle/libs.versions.toml @@ -9,6 +9,9 @@ activityCompose = "1.8.0" kotlin = "2.2.10" composeBom = "2026.02.01" sentry = "8.14.0" +okhttp = "4.12.0" +kotlinxCoroutines = "1.10.2" +kotlinxSerializationJson = "1.8.1" [libraries] androidx-core-ktx = { group = "androidx.core", name = "core-ktx", version.ref = "coreKtx" } @@ -16,6 +19,7 @@ junit = { group = "junit", name = "junit", version.ref = "junit" } androidx-junit = { group = "androidx.test.ext", name = "junit", version.ref = "junitVersion" } androidx-espresso-core = { group = "androidx.test.espresso", name = "espresso-core", version.ref = "espressoCore" } androidx-lifecycle-runtime-ktx = { group = "androidx.lifecycle", name = "lifecycle-runtime-ktx", version.ref = "lifecycleRuntimeKtx" } +androidx-lifecycle-viewmodel-ktx = { group = "androidx.lifecycle", name = "lifecycle-viewmodel-ktx", version.ref = "lifecycleRuntimeKtx" } androidx-activity-compose = { group = "androidx.activity", name = "activity-compose", version.ref = "activityCompose" } androidx-compose-bom = { group = "androidx.compose", name = "compose-bom", version.ref = "composeBom" } androidx-compose-ui = { group = "androidx.compose.ui", name = "ui" } @@ -26,8 +30,13 @@ androidx-compose-ui-test-manifest = { group = "androidx.compose.ui", name = "ui- androidx-compose-ui-test-junit4 = { group = "androidx.compose.ui", name = "ui-test-junit4" } androidx-compose-material3 = { group = "androidx.compose.material3", name = "material3" } sentry-android = { group = "io.sentry", name = "sentry-android", version.ref = "sentry" } +okhttp = { group = "com.squareup.okhttp3", name = "okhttp", version.ref = "okhttp" } +kotlinx-coroutines-android = { group = "org.jetbrains.kotlinx", name = "kotlinx-coroutines-android", version.ref = "kotlinxCoroutines" } +kotlinx-coroutines-test = { group = "org.jetbrains.kotlinx", name = "kotlinx-coroutines-test", version.ref = "kotlinxCoroutines" } +kotlinx-serialization-json = { group = "org.jetbrains.kotlinx", name = "kotlinx-serialization-json", version.ref = "kotlinxSerializationJson" } [plugins] android-application = { id = "com.android.application", version.ref = "agp" } kotlin-compose = { id = "org.jetbrains.kotlin.plugin.compose", version.ref = "kotlin" } +kotlin-serialization = { id = "org.jetbrains.kotlin.plugin.serialization", version.ref = "kotlin" } diff --git a/web/docs/_index.md b/web/docs/_index.md index eb6109a..8342d12 100644 --- a/web/docs/_index.md +++ b/web/docs/_index.md @@ -11,3 +11,4 @@ The `docs` agent reads this index first to locate the right file. | Component conventions | [components.md](components.md) | `app/` (all component files) | | Testing patterns | [testing.md](testing.md) | `vitest.config.ts`, `vitest.setup.ts`, `__tests__/page.test.tsx` | | Observability (Sentry error tracking) | [observability.md](observability.md) | `sentry.client.config.ts`, `sentry.server.config.ts`, `sentry.edge.config.ts`, `next.config.ts`, `.env.example` | +| WebSocket hook (useWebSocket, reconnect, auth) | [websocket.md](websocket.md) | `lib/useWebSocket.ts`, `lib/useWebSocket.test.ts` | diff --git a/web/docs/websocket.md b/web/docs/websocket.md new file mode 100644 index 0000000..285781b --- /dev/null +++ b/web/docs/websocket.md @@ -0,0 +1,118 @@ +--- +topic: websocket +last_verified: 2026-06-15 +sources: + - lib/useWebSocket.ts + - lib/useWebSocket.test.ts +--- + +# WebSocket + +## Overview + +The `useWebSocket` hook in `lib/useWebSocket.ts` opens a WebSocket connection to the +backend `GET /ws` endpoint, passes the Firebase ID token as a query parameter, and +reconnects automatically with exponential backoff when the connection drops. + +## Typed message envelope + +```ts +export interface WsEnvelope { + type: string // dot-separated event name, e.g. "job.completed" + payload: unknown // shape determined by type; cast on the receiving end +} +``` + +## useWebSocket hook + +```ts +import { useWebSocket } from '@/lib/useWebSocket' + +const { isConnected, lastMessage, send } = useWebSocket({ + url: 'ws://localhost:8080/ws', + token: firebaseIdToken, // optional — required by backend in staging/prod + onMessage: (envelope) => { // optional callback for each message + if (envelope.type === 'job.completed') { /* ... */ } + }, + maxRetries: 10, // optional, default 10 +}) +``` + +### Options + +| Option | Type | Default | Description | +|---|---|---|---| +| `url` | `string` | required | WebSocket URL (no token suffix needed) | +| `token` | `string` | `undefined` | Firebase ID token — appended as `?token=` | +| `onMessage` | `(e: WsEnvelope) => void` | `undefined` | Called for each valid message | +| `maxRetries` | `number` | `10` | Maximum reconnection attempts | + +### Return value + +| Field | Type | Description | +|---|---|---| +| `isConnected` | `boolean` | `true` while the socket is open | +| `lastMessage` | `WsEnvelope \| null` | Most recently received envelope | +| `send` | `(e: WsEnvelope) => void` | Send a message; no-op when disconnected | + +## Reconnection + +After a disconnect, the hook retries with exponential backoff capped at 30 seconds: + +``` +delay = min(1000ms × 2^retryCount, 30 000ms) +``` + +Retry count resets to 0 on a successful reconnect. No further retries are attempted +after `maxRetries` failures. Pending retry timers are cleared on unmount. + +## Authentication + +Append the Firebase ID token as a query parameter: + +```ts +useWebSocket({ url: 'ws://localhost:8080/ws', token: idToken }) +// → connects to ws://localhost:8080/ws?token= +``` + +The backend (`GET /ws`) rejects connections without a valid token with HTTP 401 before +the WebSocket upgrade completes. + +## `"use client"` requirement + +`useWebSocket` is a Client Component hook — it requires `"use client"` and must not +be imported from Server Components. Wrap it in a Client Component that receives +the token as a prop from a Server Component. + +```tsx +// app/live/LiveFeed.tsx (Server Component — fetches token server-side) +import LiveFeedClient from './LiveFeedClient' +export default async function LiveFeed() { + const token = await getFirebaseToken() + return +} + +// app/live/LiveFeedClient.tsx +'use client' +import { useWebSocket } from '@/lib/useWebSocket' +export default function LiveFeedClient({ token }: { token: string }) { + const { isConnected, lastMessage } = useWebSocket({ url: process.env.NEXT_PUBLIC_WS_URL!, token }) + // ... +} +``` + +## Testing + +Tests live in `lib/useWebSocket.test.ts` and use Vitest + `@testing-library/react` +with a `MockWebSocket` class injected via `vi.stubGlobal('WebSocket', MockWebSocket)`. +Fake timers (`vi.useFakeTimers`) drive reconnection delays synchronously. + +Covered cases: +- Correct URL construction (with and without token) +- `isConnected` state transitions on open/close +- `onMessage` callback and `lastMessage` state update +- Non-JSON message tolerance +- Exponential backoff reconnection +- `maxRetries` limit +- Cleanup on unmount (close + timer clear) +- `send()` delegates to `WebSocket.send()` diff --git a/web/lib/useWebSocket.test.ts b/web/lib/useWebSocket.test.ts new file mode 100644 index 0000000..8bcb065 --- /dev/null +++ b/web/lib/useWebSocket.test.ts @@ -0,0 +1,149 @@ +import { renderHook, act } from '@testing-library/react' +import { beforeEach, afterEach, describe, it, expect, vi } from 'vitest' +import { useWebSocket } from './useWebSocket' + +// Minimal WebSocket mock that captures handlers and lets tests trigger them. +class MockWebSocket { + static CONNECTING = 0 + static OPEN = 1 + static CLOSING = 2 + static CLOSED = 3 + + readyState = MockWebSocket.OPEN + url: string + onopen: ((e: Event) => void) | null = null + onmessage: ((e: MessageEvent) => void) | null = null + onclose: ((e: CloseEvent) => void) | null = null + onerror: ((e: Event) => void) | null = null + + static instances: MockWebSocket[] = [] + + constructor(url: string) { + this.url = url + MockWebSocket.instances.push(this) + } + + send = vi.fn() + close = vi.fn(() => { + this.readyState = MockWebSocket.CLOSED + this.onclose?.(new CloseEvent('close')) + }) + + // Test helpers + simulateOpen() { + this.readyState = MockWebSocket.OPEN + this.onopen?.(new Event('open')) + } + simulateMessage(data: string) { + this.onmessage?.(new MessageEvent('message', { data })) + } + simulateClose() { + this.readyState = MockWebSocket.CLOSED + this.onclose?.(new CloseEvent('close')) + } +} + +beforeEach(() => { + MockWebSocket.instances = [] + vi.useFakeTimers() + vi.stubGlobal('WebSocket', MockWebSocket) +}) + +afterEach(() => { + vi.useRealTimers() + vi.unstubAllGlobals() +}) + +describe('useWebSocket', () => { + it('connects to the correct URL without a token', () => { + renderHook(() => useWebSocket({ url: 'ws://localhost:8080/ws' })) + expect(MockWebSocket.instances).toHaveLength(1) + expect(MockWebSocket.instances[0].url).toBe('ws://localhost:8080/ws') + }) + + it('appends token as query param when provided', () => { + renderHook(() => useWebSocket({ url: 'ws://localhost:8080/ws', token: 'abc123' })) + expect(MockWebSocket.instances[0].url).toBe('ws://localhost:8080/ws?token=abc123') + }) + + it('sets isConnected true after open', () => { + const { result } = renderHook(() => useWebSocket({ url: 'ws://localhost:8080/ws' })) + act(() => MockWebSocket.instances[0].simulateOpen()) + expect(result.current.isConnected).toBe(true) + }) + + it('sets isConnected false after close', () => { + const { result } = renderHook(() => useWebSocket({ url: 'ws://localhost:8080/ws' })) + act(() => MockWebSocket.instances[0].simulateOpen()) + act(() => MockWebSocket.instances[0].simulateClose()) + expect(result.current.isConnected).toBe(false) + }) + + it('calls onMessage and updates lastMessage when a message arrives', () => { + const onMessage = vi.fn() + const { result } = renderHook(() => + useWebSocket({ url: 'ws://localhost:8080/ws', onMessage }), + ) + const data = JSON.stringify({ type: 'ping', payload: null }) + act(() => MockWebSocket.instances[0].simulateMessage(data)) + + expect(onMessage).toHaveBeenCalledWith({ type: 'ping', payload: null }) + expect(result.current.lastMessage).toEqual({ type: 'ping', payload: null }) + }) + + it('ignores non-JSON messages without throwing', () => { + const onMessage = vi.fn() + renderHook(() => useWebSocket({ url: 'ws://localhost:8080/ws', onMessage })) + expect(() => + act(() => MockWebSocket.instances[0].simulateMessage('not json')), + ).not.toThrow() + expect(onMessage).not.toHaveBeenCalled() + }) + + it('reconnects after close with exponential backoff', () => { + renderHook(() => useWebSocket({ url: 'ws://localhost:8080/ws' })) + expect(MockWebSocket.instances).toHaveLength(1) + + act(() => MockWebSocket.instances[0].simulateClose()) + // First retry: 1000ms * 2^0 = 1000ms + act(() => vi.advanceTimersByTime(1000)) + expect(MockWebSocket.instances).toHaveLength(2) + }) + + it('does not reconnect beyond maxRetries', () => { + renderHook(() => useWebSocket({ url: 'ws://localhost:8080/ws', maxRetries: 1 })) + + // First close triggers retry 0 → opens connection 2 + act(() => MockWebSocket.instances[0].simulateClose()) + act(() => vi.advanceTimersByTime(1000)) + expect(MockWebSocket.instances).toHaveLength(2) + + // Second close — retry count reached, no more reconnects + act(() => MockWebSocket.instances[1].simulateClose()) + act(() => vi.advanceTimersByTime(30000)) + expect(MockWebSocket.instances).toHaveLength(2) + }) + + it('closes the WebSocket and clears timers on unmount', () => { + const { unmount } = renderHook(() => useWebSocket({ url: 'ws://localhost:8080/ws' })) + const ws = MockWebSocket.instances[0] + + // Trigger a pending retry before unmounting + act(() => ws.simulateClose()) + unmount() + + // No new connection should appear after timers fire + act(() => vi.advanceTimersByTime(30000)) + expect(MockWebSocket.instances).toHaveLength(1) + expect(ws.close).toHaveBeenCalled() + }) + + it('send() sends JSON when the socket is open', () => { + const { result } = renderHook(() => useWebSocket({ url: 'ws://localhost:8080/ws' })) + const ws = MockWebSocket.instances[0] + act(() => ws.simulateOpen()) + + act(() => result.current.send({ type: 'ping', payload: null })) + expect(ws.send).toHaveBeenCalledWith(JSON.stringify({ type: 'ping', payload: null })) + }) +}) diff --git a/web/lib/useWebSocket.ts b/web/lib/useWebSocket.ts new file mode 100644 index 0000000..a70a5d4 --- /dev/null +++ b/web/lib/useWebSocket.ts @@ -0,0 +1,108 @@ +'use client' + +import { useEffect, useLayoutEffect, useRef, useState, useCallback } from 'react' + +export interface WsEnvelope { + type: string + payload: unknown +} + +export interface UseWebSocketOptions { + url: string + token?: string + onMessage?: (envelope: WsEnvelope) => void + maxRetries?: number +} + +export interface UseWebSocketReturn { + isConnected: boolean + lastMessage: WsEnvelope | null + send: (envelope: WsEnvelope) => void +} + +export function useWebSocket({ + url, + token, + onMessage, + maxRetries = 10, +}: UseWebSocketOptions): UseWebSocketReturn { + const [isConnected, setIsConnected] = useState(false) + const [lastMessage, setLastMessage] = useState(null) + + const wsRef = useRef(null) + const retriesRef = useRef(0) + const retryTimerRef = useRef | null>(null) + const unmountedRef = useRef(false) + + // Refs holding the latest option values and connect closure. + // Updated in useLayoutEffect (after render, before effects) — never during render. + const onMessageRef = useRef(onMessage) + const maxRetriesRef = useRef(maxRetries) + const connectRef = useRef<() => void>(() => {}) + + useLayoutEffect(() => { + onMessageRef.current = onMessage + maxRetriesRef.current = maxRetries + }) + + useLayoutEffect(() => { + connectRef.current = () => { + if (unmountedRef.current) return + + const wsUrl = token ? `${url}?token=${encodeURIComponent(token)}` : url + const ws = new WebSocket(wsUrl) + wsRef.current = ws + + ws.onopen = () => { + if (unmountedRef.current) return + setIsConnected(true) + retriesRef.current = 0 + } + + ws.onmessage = (event: MessageEvent) => { + if (unmountedRef.current) return + try { + const envelope = JSON.parse(event.data as string) as WsEnvelope + setLastMessage(envelope) + onMessageRef.current?.(envelope) + } catch { + // Ignore non-JSON frames. + } + } + + ws.onclose = () => { + if (unmountedRef.current) return + setIsConnected(false) + if (retriesRef.current < maxRetriesRef.current) { + const delay = Math.min(1000 * 2 ** retriesRef.current, 30_000) + retriesRef.current++ + retryTimerRef.current = setTimeout(() => connectRef.current(), delay) + } + } + + ws.onerror = () => ws.close() + } + }) + + useEffect(() => { + unmountedRef.current = false + connectRef.current() + + return () => { + unmountedRef.current = true + if (retryTimerRef.current !== null) { + clearTimeout(retryTimerRef.current) + retryTimerRef.current = null + } + wsRef.current?.close() + } + }, []) + + const send = useCallback((envelope: WsEnvelope) => { + if (wsRef.current?.readyState === WebSocket.OPEN) { + wsRef.current.send(JSON.stringify(envelope)) + } + }, []) + + return { isConnected, lastMessage, send } +} From c5bee92f297b90ed2fe0efe5aea0a9bd6fdfbc70 Mon Sep 17 00:00:00 2001 From: GRACENOBLE Date: Mon, 15 Jun 2026 16:51:45 +0300 Subject: [PATCH 2/2] fix(websocket): address CodeRabbit review findings - hub: make Publish non-blocking (select/default) to prevent worker stall - client: use non-blocking Unregister send to prevent goroutine leak on shutdown - client: replace //nolint:errcheck with explicit error checks throughout ReadPump/WritePump - ws_handler: restrict CheckOrigin to BLUEPRINT_WS_ALLOWED_ORIGIN in non-debug mode - WebSocketManager: close existing socket before opening on repeated connect() calls - useWebSocket: handle pre-existing query params when appending ?token= (use & separator) - tests: add regression test for URL with existing query params - docs: add language identifier to unlabeled fenced code blocks (markdownlint MD040) - .env.example: document BLUEPRINT_WS_ALLOWED_ORIGIN Co-Authored-By: Claude Sonnet 4.6 --- backend/.env.example | 5 ++- backend/docs/websocket.md | 4 +-- backend/internal/infrastructure/ws/client.go | 33 ++++++++++++++----- backend/internal/infrastructure/ws/hub.go | 9 +++-- .../internal/transport/handlers/ws_handler.go | 15 +++++++-- .../template/websocket/WebSocketManager.kt | 2 ++ web/docs/websocket.md | 2 +- web/lib/useWebSocket.test.ts | 7 ++++ web/lib/useWebSocket.ts | 4 ++- 9 files changed, 64 insertions(+), 17 deletions(-) diff --git a/backend/.env.example b/backend/.env.example index c19c32f..2a7ac38 100644 --- a/backend/.env.example +++ b/backend/.env.example @@ -15,4 +15,7 @@ FIREBASE_PROJECT_ID= # Service account key as a single-line JSON string. Get from: Firebase Console → Project Settings → Service Accounts → Generate new private key FIREBASE_SERVICE_ACCOUNT_JSON= # Sentry error tracking DSN (optional; leave empty to disable) -SENTRY_DSN= \ No newline at end of file +SENTRY_DSN= +# WebSocket allowed origin in staging/production — e.g. https://example.com +# Leave empty to deny all cross-origin WebSocket connections in non-debug mode. +BLUEPRINT_WS_ALLOWED_ORIGIN=http://localhost:3000 \ No newline at end of file diff --git a/backend/docs/websocket.md b/backend/docs/websocket.md index d2e8fb9..c7cf044 100644 --- a/backend/docs/websocket.md +++ b/backend/docs/websocket.md @@ -62,7 +62,7 @@ Each WebSocket connection spawns two goroutines: `ReadPump` and `WritePump` (on The Hub serialises all mutations (register / unregister / broadcast) through a `select` loop so no locking is needed on its internal `clients` map. -``` +```text caller goroutine │ hub.Publish(...) ▼ @@ -95,7 +95,7 @@ Ping/pong keepalive: `pingPeriod = 54s`, `pongWait = 60s`, `writeWait = 10s`. ## Route — GET /ws -``` +```text GET /ws?token= ``` diff --git a/backend/internal/infrastructure/ws/client.go b/backend/internal/infrastructure/ws/client.go index e39effa..08b7318 100644 --- a/backend/internal/infrastructure/ws/client.go +++ b/backend/internal/infrastructure/ws/client.go @@ -38,11 +38,18 @@ func NewClient(hub *Hub, conn *websocket.Conn) *Client { // It unregisters the client and closes the connection when done. func (c *Client) ReadPump() { defer func() { - c.hub.Unregister <- c + // Non-blocking: if Hub.Run has already exited (ctx cancelled) there is no + // receiver on Unregister. The Hub already closed c.Send in that case. + select { + case c.hub.Unregister <- c: + default: + } c.conn.Close() }() c.conn.SetReadLimit(maxMessageSize) - c.conn.SetReadDeadline(time.Now().Add(pongWait)) //nolint:errcheck + if err := c.conn.SetReadDeadline(time.Now().Add(pongWait)); err != nil { + return + } c.conn.SetPongHandler(func(string) error { return c.conn.SetReadDeadline(time.Now().Add(pongWait)) }) @@ -68,26 +75,36 @@ func (c *Client) WritePump() { for { select { case msg, ok := <-c.Send: - c.conn.SetWriteDeadline(time.Now().Add(writeWait)) //nolint:errcheck + if err := c.conn.SetWriteDeadline(time.Now().Add(writeWait)); err != nil { + return + } if !ok { - c.conn.WriteMessage(websocket.CloseMessage, []byte{}) //nolint:errcheck + _ = c.conn.WriteMessage(websocket.CloseMessage, []byte{}) return } w, err := c.conn.NextWriter(websocket.TextMessage) if err != nil { return } - w.Write(msg) //nolint:errcheck + if _, err := w.Write(msg); err != nil { + return + } // Drain any queued messages into the same WebSocket frame. for i := len(c.Send); i > 0; i-- { - w.Write([]byte{'\n'}) //nolint:errcheck - w.Write(<-c.Send) //nolint:errcheck + if _, err := w.Write([]byte{'\n'}); err != nil { + break + } + if _, err := w.Write(<-c.Send); err != nil { + break + } } if err := w.Close(); err != nil { return } case <-ticker.C: - c.conn.SetWriteDeadline(time.Now().Add(writeWait)) //nolint:errcheck + if err := c.conn.SetWriteDeadline(time.Now().Add(writeWait)); err != nil { + return + } if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil { return } diff --git a/backend/internal/infrastructure/ws/hub.go b/backend/internal/infrastructure/ws/hub.go index 0497294..756648a 100644 --- a/backend/internal/infrastructure/ws/hub.go +++ b/backend/internal/infrastructure/ws/hub.go @@ -3,6 +3,7 @@ package ws import ( "context" "encoding/json" + "errors" ) // Hub maintains the set of active WebSocket clients and broadcasts messages to them. @@ -67,6 +68,10 @@ func (h *Hub) Publish(msgType string, payload any) error { if err != nil { return err } - h.broadcast <- env - return nil + select { + case h.broadcast <- env: + return nil + default: + return errors.New("ws: broadcast channel full, message dropped") + } } diff --git a/backend/internal/transport/handlers/ws_handler.go b/backend/internal/transport/handlers/ws_handler.go index 66aa030..0913dfe 100644 --- a/backend/internal/transport/handlers/ws_handler.go +++ b/backend/internal/transport/handlers/ws_handler.go @@ -2,6 +2,7 @@ package handlers import ( "net/http" + "os" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" @@ -12,8 +13,18 @@ import ( var upgrader = websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, - // TODO: restrict to known origins in production. - CheckOrigin: func(r *http.Request) bool { return true }, + // In debug mode (local dev) all origins are allowed so browsers can connect + // from any port. In staging/production, restrict to BLUEPRINT_WS_ALLOWED_ORIGIN. + CheckOrigin: func(r *http.Request) bool { + if gin.IsDebugging() { + return true + } + allowed := os.Getenv("BLUEPRINT_WS_ALLOWED_ORIGIN") + if allowed == "" { + return false + } + return r.Header.Get("Origin") == allowed + }, } // WsHandler godoc diff --git a/mobile/app/src/main/java/com/company/template/websocket/WebSocketManager.kt b/mobile/app/src/main/java/com/company/template/websocket/WebSocketManager.kt index 02c410f..f249f90 100644 --- a/mobile/app/src/main/java/com/company/template/websocket/WebSocketManager.kt +++ b/mobile/app/src/main/java/com/company/template/websocket/WebSocketManager.kt @@ -44,6 +44,8 @@ class WebSocketManager( private set fun connect(token: String?) { + socket?.close(1000, "reconnecting") + socket = null active = true currentToken = token retryCount = 0 diff --git a/web/docs/websocket.md b/web/docs/websocket.md index 285781b..fd89486 100644 --- a/web/docs/websocket.md +++ b/web/docs/websocket.md @@ -59,7 +59,7 @@ const { isConnected, lastMessage, send } = useWebSocket({ After a disconnect, the hook retries with exponential backoff capped at 30 seconds: -``` +```text delay = min(1000ms × 2^retryCount, 30 000ms) ``` diff --git a/web/lib/useWebSocket.test.ts b/web/lib/useWebSocket.test.ts index 8bcb065..5c028fa 100644 --- a/web/lib/useWebSocket.test.ts +++ b/web/lib/useWebSocket.test.ts @@ -66,6 +66,13 @@ describe('useWebSocket', () => { expect(MockWebSocket.instances[0].url).toBe('ws://localhost:8080/ws?token=abc123') }) + it('appends token with & when url already has query params', () => { + renderHook(() => + useWebSocket({ url: 'ws://localhost:8080/ws?room=1', token: 'abc123' }), + ) + expect(MockWebSocket.instances[0].url).toBe('ws://localhost:8080/ws?room=1&token=abc123') + }) + it('sets isConnected true after open', () => { const { result } = renderHook(() => useWebSocket({ url: 'ws://localhost:8080/ws' })) act(() => MockWebSocket.instances[0].simulateOpen()) diff --git a/web/lib/useWebSocket.ts b/web/lib/useWebSocket.ts index a70a5d4..23aa808 100644 --- a/web/lib/useWebSocket.ts +++ b/web/lib/useWebSocket.ts @@ -49,7 +49,9 @@ export function useWebSocket({ connectRef.current = () => { if (unmountedRef.current) return - const wsUrl = token ? `${url}?token=${encodeURIComponent(token)}` : url + const wsUrl = token + ? `${url}${url.includes('?') ? '&' : '?'}token=${encodeURIComponent(token)}` + : url const ws = new WebSocket(wsUrl) wsRef.current = ws