From 502df8a63feef4b08cacbd40cbf2e84252ddd1d1 Mon Sep 17 00:00:00 2001 From: Ihar Kryvanos Date: Fri, 10 Apr 2026 17:05:31 +0200 Subject: [PATCH] Fix proxy hang when target is unreachable during client-streaming RPCs TargetStream.Send selected on the gRPC stream context to detect cancellation, but before the stream is established (during dial/NewStream), that context belongs to the placeholder unconnectedClientStream and never gets cancelled on connection failure. This caused Send to block forever when reqChan was full and the target was unreachable. Two changes fix this: 1. Send now selects on s.ctx (the TargetStream's own context) instead of the gRPC stream context. s.ctx is derived from the caller's context and is cancelled when the parent context is cancelled or when Run() fails. 2. Run() now calls s.cancelFunc() when NewStream fails. Previously only DialContext failure cancelled the stream context; a NewStream failure left in-flight Send() calls blocked indefinitely. Together these ensure that when a target is unreachable, the proxy propagates the error back to the client instead of hanging. Made-with: Cursor --- proxy/server/target.go | 12 +++-- proxy/server/target_test.go | 93 ++++++++++++++++++++++++++++++++++++- 2 files changed, 99 insertions(+), 6 deletions(-) diff --git a/proxy/server/target.go b/proxy/server/target.go index b85e2f8a..712b07fb 100644 --- a/proxy/server/target.go +++ b/proxy/server/target.go @@ -176,14 +176,14 @@ func (s *TargetStream) CloseWith(err error) { } // Send the supplied request to the target stream, returning -// an error if the context has already been cancelled. +// an error if the stream's context has already been cancelled +// (e.g. due to a dial failure to the target). func (s *TargetStream) Send(req proto.Message) error { - ctx := s.getStream().Context() select { case s.reqChan <- req: return nil - case <-ctx.Done(): - return ctx.Err() + case <-s.ctx.Done(): + return s.ctx.Err() } } @@ -219,8 +219,10 @@ func (s *TargetStream) Run(nonce uint32, replyChan chan *pb.ProxyReply) { s.grpcConn = grpcConn grpcStream, err := s.grpcConn.NewStream(ctx, s.serviceMethod.StreamDesc(), s.serviceMethod.FullName()) if err != nil { - // We cannot create a new stream to the target. So we need to cancel this stream. + // We cannot create a new stream to the target. Cancel the + // stream context so that any in-flight Send() calls unblock. s.logger.Info("unable to create stream", "status", err) + s.cancelFunc() return fmt.Errorf("could not connect to target from the proxy: %w", err) } diff --git a/proxy/server/target_test.go b/proxy/server/target_test.go index 61cc7805..eed68690 100644 --- a/proxy/server/target_test.go +++ b/proxy/server/target_test.go @@ -29,7 +29,7 @@ import ( "google.golang.org/protobuf/types/known/anypb" pb "github.com/Snowflake-Labs/sansshell/proxy" - _ "github.com/Snowflake-Labs/sansshell/proxy/testdata" + td "github.com/Snowflake-Labs/sansshell/proxy/testdata" "github.com/Snowflake-Labs/sansshell/testing/testutil" ) @@ -188,6 +188,97 @@ func TestTargetStreamAddNonBlocking(t *testing.T) { } } +// TestSendUnblocksWhenTargetUnreachable verifies that TargetStream.Send does +// not hang when the target is unreachable and the reqChan buffer is full. +// Before the fix, Send selected on the gRPC stream's context (which stayed +// alive because the stream hadn't been established yet). After the fix, Send +// selects on the TargetStream's own context, which is cancelled when the dial +// or NewStream fails. +func TestSendUnblocksWhenTargetUnreachable(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + serviceMap := LoadGlobalServiceMap() + ss := NewTargetStreamSet(serviceMap, blockingClientDialer{}, nil) + + replyChan := make(chan *pb.ProxyReply, 100) + doneChan := make(chan uint64, 1) + + req := &pb.StartStream{ + Target: "unreachable:9500", + Nonce: 42, + MethodName: "/Testdata.TestService/TestClientStream", + } + if err := ss.Add(ctx, req, replyChan, doneChan); err != nil { + t.Fatalf("Add: %v", err) + } + + // The reply should contain a stream ID (not an error) because dial + // is non-blocking; the actual connection is attempted in Run(). + var streamID uint64 + select { + case msg := <-replyChan: + sid := msg.GetStartStreamReply().GetStreamId() + if sid == 0 { + t.Fatalf("expected stream ID, got: %+v", msg) + } + streamID = sid + case <-time.After(2 * time.Second): + t.Fatal("no reply from Add") + } + + payload, err := anypb.New(&td.TestRequest{Input: "chunk"}) + if err != nil { + t.Fatal(err) + } + data := &pb.StreamData{ + StreamIds: []uint64{streamID}, + Payload: payload, + } + + // Fill the reqChan buffer. + for i := 0; i < ReqBufferSize; i++ { + if err := ss.Send(ctx, data); err != nil { + t.Fatalf("Send[%d]: %v", i, err) + } + } + + // The next Send would block forever with the old code because: + // - reqChan is full + // - blockingClientConn.NewStream blocks, so Run hasn't started + // consuming reqChan yet + // - the old Send checked the gRPC stream context (unconnectedClientStream), + // which is still alive + // + // After the fix, Send checks s.ctx. When we cancel the parent context, + // s.ctx (derived from it) is also cancelled, and Send unblocks. + sendDone := make(chan error, 1) + go func() { + sendDone <- ss.Send(ctx, data) + }() + + // Give it a moment to confirm it's actually blocked. + select { + case err := <-sendDone: + t.Fatalf("Send returned immediately (want block): %v", err) + case <-time.After(100 * time.Millisecond): + // expected: Send is blocked on full buffer + } + + // Cancel the context — this should unblock Send promptly. + cancel() + + select { + case err := <-sendDone: + if err == nil { + t.Fatal("Send returned nil after cancel, want context error") + } + t.Logf("Send unblocked with: %v (good)", err) + case <-time.After(2 * time.Second): + t.Fatal("Send still blocked 2s after context cancel — BUG: stall not fixed") + } +} + func TestIsCardinalityViolation(t *testing.T) { for _, tc := range []struct { name string