diff --git a/proxy/server/target.go b/proxy/server/target.go index b85e2f8a..712b07fb 100644 --- a/proxy/server/target.go +++ b/proxy/server/target.go @@ -176,14 +176,14 @@ func (s *TargetStream) CloseWith(err error) { } // Send the supplied request to the target stream, returning -// an error if the context has already been cancelled. +// an error if the stream's context is cancelled before the request can +// be queued (e.g. when creating the stream to the target fails). func (s *TargetStream) Send(req proto.Message) error { - ctx := s.getStream().Context() select { case s.reqChan <- req: return nil - case <-ctx.Done(): - return ctx.Err() + case <-s.ctx.Done(): + return s.ctx.Err() } } @@ -219,8 +219,10 @@ func (s *TargetStream) Run(nonce uint32, replyChan chan *pb.ProxyReply) { s.grpcConn = grpcConn grpcStream, err := s.grpcConn.NewStream(ctx, s.serviceMethod.StreamDesc(), s.serviceMethod.FullName()) if err != nil { - // We cannot create a new stream to the target. So we need to cancel this stream. + // We cannot create a new stream to the target. Cancel the + // stream context so that any in-flight Send() calls unblock. s.logger.Info("unable to create stream", "status", err) + s.cancelFunc() return fmt.Errorf("could not connect to target from the proxy: %w", err) } diff --git a/proxy/server/target_test.go b/proxy/server/target_test.go index 61cc7805..eed68690 100644 --- a/proxy/server/target_test.go +++ b/proxy/server/target_test.go @@ -29,7 +29,7 @@ import ( "google.golang.org/protobuf/types/known/anypb" pb "github.com/Snowflake-Labs/sansshell/proxy" - _ "github.com/Snowflake-Labs/sansshell/proxy/testdata" + td "github.com/Snowflake-Labs/sansshell/proxy/testdata" "github.com/Snowflake-Labs/sansshell/testing/testutil" ) @@ -188,6 +188,97 @@ func TestTargetStreamAddNonBlocking(t *testing.T) { } } +// TestSendUnblocksWhenTargetUnreachable verifies that TargetStream.Send does +// not hang when the target is unreachable and the reqChan buffer is full. 
+// Before the fix, Send selected on the gRPC stream's context (which stayed +// alive because the stream hadn't been established yet). After the fix, Send +// selects on the TargetStream's own context. NOTE(review): the test unblocks +// Send by cancelling the caller's context rather than via a NewStream error. +func TestSendUnblocksWhenTargetUnreachable(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + serviceMap := LoadGlobalServiceMap() + ss := NewTargetStreamSet(serviceMap, blockingClientDialer{}, nil) + + replyChan := make(chan *pb.ProxyReply, 100) + doneChan := make(chan uint64, 1) + + req := &pb.StartStream{ + Target: "unreachable:9500", + Nonce: 42, + MethodName: "/Testdata.TestService/TestClientStream", + } + if err := ss.Add(ctx, req, replyChan, doneChan); err != nil { + t.Fatalf("Add: %v", err) + } + + // The reply should contain a stream ID (not an error) because dial + // is non-blocking; the actual connection is attempted in Run(). + var streamID uint64 + select { + case msg := <-replyChan: + sid := msg.GetStartStreamReply().GetStreamId() + if sid == 0 { + t.Fatalf("expected stream ID, got: %+v", msg) + } + streamID = sid + case <-time.After(2 * time.Second): + t.Fatal("no reply from Add") + } + + payload, err := anypb.New(&td.TestRequest{Input: "chunk"}) + if err != nil { + t.Fatal(err) + } + data := &pb.StreamData{ + StreamIds: []uint64{streamID}, + Payload: payload, + } + + // Fill the reqChan buffer. + for i := 0; i < ReqBufferSize; i++ { + if err := ss.Send(ctx, data); err != nil { + t.Fatalf("Send[%d]: %v", i, err) + } + } + + // The next Send would block forever with the old code because: + // - reqChan is full + // - blockingClientConn.NewStream blocks, so Run hasn't started + // consuming reqChan yet + // - the old Send checked the gRPC stream context (unconnectedClientStream), + // which is still alive + // + // After the fix, Send checks s.ctx. When we cancel the parent context, + // s.ctx (derived from it) is also cancelled, and Send unblocks. 
+ sendDone := make(chan error, 1) + go func() { + sendDone <- ss.Send(ctx, data) + }() + + // Give it a moment to confirm it's actually blocked. + select { + case err := <-sendDone: + t.Fatalf("Send returned immediately (want block): %v", err) + case <-time.After(100 * time.Millisecond): + // expected: Send is blocked on full buffer + } + + // Cancel the context — this should unblock Send promptly. + cancel() + + select { + case err := <-sendDone: + if err == nil { + t.Fatal("Send returned nil after cancel, want context error") + } + t.Logf("Send unblocked with: %v (good)", err) + case <-time.After(2 * time.Second): + t.Fatal("Send still blocked 2s after context cancel — BUG: stall not fixed") + } +} + func TestIsCardinalityViolation(t *testing.T) { for _, tc := range []struct { name string