Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions proxy/server/target.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,14 +176,14 @@ func (s *TargetStream) CloseWith(err error) {
}

// Send queues req on s.reqChan for delivery to the target stream and
// returns nil once the channel has accepted it.
//
// If the request cannot be queued, Send stops waiting when either of two
// contexts is done and returns that context's error:
//   - the context of the stream returned by s.getStream(), or
//   - s.ctx, the stream's own context (per the original comment, cancelled
//     e.g. due to a dial failure to the target).
//
// NOTE(review): this listing appears to interleave the pre- and post-change
// lines of a diff; confirm whether the getStream() context case is still
// meant to exist alongside the s.ctx case.
func (s *TargetStream) Send(req proto.Message) error {
// Context of the stream currently returned by getStream.
ctx := s.getStream().Context()
select {
case s.reqChan <- req:
// Request accepted by the channel.
return nil
case <-ctx.Done():
return ctx.Err()
case <-s.ctx.Done():
return s.ctx.Err()
}
}

Expand Down Expand Up @@ -219,8 +219,10 @@ func (s *TargetStream) Run(nonce uint32, replyChan chan *pb.ProxyReply) {
s.grpcConn = grpcConn
grpcStream, err := s.grpcConn.NewStream(ctx, s.serviceMethod.StreamDesc(), s.serviceMethod.FullName())
if err != nil {
// We cannot create a new stream to the target. So we need to cancel this stream.
// We cannot create a new stream to the target. Cancel the
// stream context so that any in-flight Send() calls unblock.
s.logger.Info("unable to create stream", "status", err)
s.cancelFunc()
return fmt.Errorf("could not connect to target from the proxy: %w", err)
}

Expand Down
93 changes: 92 additions & 1 deletion proxy/server/target_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import (
"google.golang.org/protobuf/types/known/anypb"

pb "github.com/Snowflake-Labs/sansshell/proxy"
_ "github.com/Snowflake-Labs/sansshell/proxy/testdata"
td "github.com/Snowflake-Labs/sansshell/proxy/testdata"
"github.com/Snowflake-Labs/sansshell/testing/testutil"
)

Expand Down Expand Up @@ -188,6 +188,97 @@ func TestTargetStreamAddNonBlocking(t *testing.T) {
}
}

// TestSendUnblocksWhenTargetUnreachable verifies that TargetStream.Send does
// not hang when the target is unreachable and the reqChan buffer is full.
// Before the fix, Send selected on the gRPC stream's context (which stayed
// alive because the stream hadn't been established yet). After the fix, Send
// selects on the TargetStream's own context, which is cancelled when the dial
// or NewStream fails.
func TestSendUnblocksWhenTargetUnreachable(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

serviceMap := LoadGlobalServiceMap()
ss := NewTargetStreamSet(serviceMap, blockingClientDialer{}, nil)

replyChan := make(chan *pb.ProxyReply, 100)
doneChan := make(chan uint64, 1)

req := &pb.StartStream{
Target: "unreachable:9500",
Nonce: 42,
MethodName: "/Testdata.TestService/TestClientStream",
}
if err := ss.Add(ctx, req, replyChan, doneChan); err != nil {
t.Fatalf("Add: %v", err)
}

// The reply should contain a stream ID (not an error) because dial
// is non-blocking; the actual connection is attempted in Run().
var streamID uint64
select {
case msg := <-replyChan:
sid := msg.GetStartStreamReply().GetStreamId()
if sid == 0 {
t.Fatalf("expected stream ID, got: %+v", msg)
}
streamID = sid
case <-time.After(2 * time.Second):
t.Fatal("no reply from Add")
}

payload, err := anypb.New(&td.TestRequest{Input: "chunk"})
if err != nil {
t.Fatal(err)
}
data := &pb.StreamData{
StreamIds: []uint64{streamID},
Payload: payload,
}

// Fill the reqChan buffer.
for i := 0; i < ReqBufferSize; i++ {
if err := ss.Send(ctx, data); err != nil {
t.Fatalf("Send[%d]: %v", i, err)
}
}

// The next Send would block forever with the old code because:
// - reqChan is full
// - blockingClientConn.NewStream blocks, so Run hasn't started
// consuming reqChan yet
// - the old Send checked the gRPC stream context (unconnectedClientStream),
// which is still alive
//
// After the fix, Send checks s.ctx. When we cancel the parent context,
// s.ctx (derived from it) is also cancelled, and Send unblocks.
sendDone := make(chan error, 1)
go func() {
sendDone <- ss.Send(ctx, data)
}()

// Give it a moment to confirm it's actually blocked.
select {
case err := <-sendDone:
t.Fatalf("Send returned immediately (want block): %v", err)
case <-time.After(100 * time.Millisecond):
// expected: Send is blocked on full buffer
}

// Cancel the context — this should unblock Send promptly.
cancel()

select {
case err := <-sendDone:
if err == nil {
t.Fatal("Send returned nil after cancel, want context error")
}
t.Logf("Send unblocked with: %v (good)", err)
case <-time.After(2 * time.Second):
t.Fatal("Send still blocked 2s after context cancel — BUG: stall not fixed")
}
}

func TestIsCardinalityViolation(t *testing.T) {
for _, tc := range []struct {
name string
Expand Down
Loading