diff --git a/sdk/go/discovery/discovery.go b/sdk/go/discovery/discovery.go
index 125dccf..0cda387 100644
--- a/sdk/go/discovery/discovery.go
+++ b/sdk/go/discovery/discovery.go
@@ -10,6 +10,7 @@ import (
 	"log/slog"
 	"net/http"
 	"net/url"
+	"strconv"
 	"strings"
 	"sync"
 	"time"
@@ -36,7 +37,8 @@ import (
 // then error. Not cached.
 //
 // NOTE: a Resolver is safe for concurrent use; the in-process memo is
-// guarded by a mutex.
+// guarded by a mutex and concurrent Resolve calls for the same address
+// are coalesced into a single probe (single-flight).
 type Resolver struct {
 	fileCache *cache
 	httpc     *http.Client
@@ -44,6 +46,22 @@ type Resolver struct {
 
 	mu       sync.Mutex
 	memCache map[string]NodeInfo
+	inflight map[string]*resolveCall
+}
+
+// resolveCall is the shared state of a single in-flight Resolve for one
+// address.
+// The elected goroutine creates it, runs the disk-cache check and
+// network probe, and closes done; every concurrent caller for the same
+// address reads info/err once done is closed and skips the probe.
+// NOTE: info, err, and waiters must only be read or written with the
+// owning Resolver's mu held until done is closed; after close, info and
+// err are immutable and safe to read without the lock.
+type resolveCall struct {
+	done    chan struct{}
+	info    NodeInfo
+	err     error
+	waiters int
 }
 
 // New constructs a Resolver from functional options.
@@ -71,11 +89,16 @@ func New(opts ...Option) (*Resolver, error) {
 		httpc:     o.httpClient,
 		logger:    o.logger,
 		memCache:  map[string]NodeInfo{},
+		inflight:  map[string]*resolveCall{},
 	}, nil
 }
 
 // Resolve returns the NodeInfo for addr.
 // Trailing slashes in addr are normalized away before lookup.
+// Concurrent callers for the same address are coalesced: the first
+// caller probes under its own ctx; every other caller waits for that
+// result (same NodeInfo or error) or returns early when its ctx ends.
+// A failed probe is not memoized; the next caller retries the network.
 func (r *Resolver) Resolve(ctx context.Context, addr string) (NodeInfo, error) {
 	addr = strings.TrimRight(addr, "/")
 
@@ -84,11 +107,45 @@ func (r *Resolver) Resolve(ctx context.Context, addr string) (NodeInfo, error) {
 		r.mu.Unlock()
 		return info, nil
 	}
+	if call, ok := r.inflight[addr]; ok {
+		call.waiters++
+		r.mu.Unlock()
+		select {
+		case <-call.done:
+			return call.info, call.err
+		case <-ctx.Done():
+			return NodeInfo{}, ctx.Err()
+		}
+	}
+	call := &resolveCall{done: make(chan struct{})}
+	r.inflight[addr] = call
+	r.mu.Unlock()
+
+	info, err := r.resolveUncached(ctx, addr)
+
+	r.mu.Lock()
+	call.info = info
+	call.err = err
+	if err == nil {
+		r.memCache[addr] = info
+	}
+	delete(r.inflight, addr)
+	// Close before releasing the lock so a new caller arriving on the
+	// same addr cannot register a fresh probe while existing waiters
+	// are still blocked on this call's done channel.
+	close(call.done)
 	r.mu.Unlock()
 
+	return info, err
+}
+
+// resolveUncached runs the disk-cache check and network probe for addr.
+// NOTE: callers must hold the single-flight election for addr;
+// this helper does no in-process memoization and is not safe to call
+// directly outside the Resolve flow.
+func (r *Resolver) resolveUncached(ctx context.Context, addr string) (NodeInfo, error) {
 	if r.fileCache != nil {
 		if info, ok := r.fileCache.get(addr); ok {
-			r.cacheInMemory(addr, info)
 			return info, nil
 		}
 	}
@@ -107,17 +164,9 @@ func (r *Resolver) Resolve(ctx context.Context, addr string) (NodeInfo, error) {
 			r.logger.Warn("discovery: cache write failed", "addr", addr, "err", err)
 		}
 	}
-	r.cacheInMemory(addr, info)
 	return info, nil
 }
 
-// cacheInMemory records a resolved NodeInfo for addr in the in-process memo.
-func (r *Resolver) cacheInMemory(addr string, info NodeInfo) {
-	r.mu.Lock()
-	r.memCache[addr] = info
-	r.mu.Unlock()
-}
-
 // fetchWithRetry issues a GET to u and retries up to attempts times on
 // transport errors and 5xx responses, with a short linear backoff
 // between attempts.
@@ -254,6 +303,16 @@ func validate(info NodeInfo) error {
 	if parsed.Host == "" {
 		return fmt.Errorf("api_base_url %q is missing a host", info.APIBaseURL)
 	}
+	// url.Parse already rejects non-numeric port segments, but it
+	// accepts numeric ports outside the uint16 range (e.g.
+	// "https://host:99999/").
+	// Reject those here so the failure surfaces as a discovery-domain
+	// message rather than an opaque transport error.
+	if port := parsed.Port(); port != "" {
+		if _, err := strconv.ParseUint(port, 10, 16); err != nil {
+			return fmt.Errorf("api_base_url %q has an invalid port %q", info.APIBaseURL, port)
+		}
+	}
 	if info.APIVersion == "" {
 		return errors.New("api_version is required")
 	}
diff --git a/sdk/go/discovery/discovery_test.go b/sdk/go/discovery/discovery_test.go
index cdfd003..f24c7e1 100644
--- a/sdk/go/discovery/discovery_test.go
+++ b/sdk/go/discovery/discovery_test.go
@@ -9,6 +9,8 @@ import (
 	"os"
 	"path/filepath"
 	"strings"
+	"sync"
+	"sync/atomic"
 	"testing"
 	"time"
 
@@ -368,3 +370,271 @@ func TestResolveTrimsTrailingSlashInAddr(t *testing.T) {
 	require.NoError(t, err)
 	require.Equal(t, srv.URL+DefaultAPIPath, info.APIBaseURL)
 }
+
+func TestResolveErrorsOnNonNumericPortInAPIBaseURL(t *testing.T) {
+	t.Parallel()
+
+	// url.Parse itself rejects a non-numeric port ("invalid port ...
+	// after host"), so this input fails while api_base_url is parsed,
+	// before any port-range check can run.
+	// Pin that Resolve still fails up front with an error that mentions
+	// the port, whichever layer produced it.
+	var calls int32
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		atomic.AddInt32(&calls, 1)
+		w.Header().Set("Content-Type", "application/json")
+		_, _ = w.Write([]byte(`{
+			"version": 1,
+			"api_base_url": "https://example.com:bad/api/v1",
+			"api_version": "v1"
+		}`))
+	}))
+	defer srv.Close()
+
+	r := newTestResolver(t)
+	_, err := r.Resolve(context.Background(), srv.URL)
+	require.Error(t, err)
+	require.Contains(t, strings.ToLower(err.Error()), "port")
+
+	// A failed validation is not memoized, so a second call probes
+	// the network again rather than returning a cached error.
+	_, err = r.Resolve(context.Background(), srv.URL)
+	require.Error(t, err)
+	require.Equal(t, int32(2), atomic.LoadInt32(&calls))
+}
+
+func TestResolveErrorsOnOutOfRangePortInAPIBaseURL(t *testing.T) {
+	t.Parallel()
+
+	// url.Parse accepts numeric ports outside the uint16 range, so
+	// without the strconv.ParseUint check in validate() the failure
+	// would surface later as an opaque transport error.
+	// Pin that an out-of-range numeric port is rejected at validation
+	// time with a domain-y message naming the offending field.
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("Content-Type", "application/json")
+		_, _ = w.Write([]byte(`{
+			"version": 1,
+			"api_base_url": "https://example.com:99999/api/v1",
+			"api_version": "v1"
+		}`))
+	}))
+	defer srv.Close()
+
+	_, err := newTestResolver(t).Resolve(context.Background(), srv.URL)
+	require.Error(t, err)
+	require.Contains(t, strings.ToLower(err.Error()), "port")
+}
+
+// waitForInflightWaiter blocks until at least n goroutines have registered as
+// waiters on the in-flight Resolve for addr.
+// The elected prober does not count toward n.
+// Used only by single-flight tests to remove the "did the second
+// caller register yet" race without exposing internal state to
+// production callers.
+func (r *Resolver) waitForInflightWaiter(t *testing.T, addr string, n int) {
+	t.Helper()
+	deadline := time.Now().Add(2 * time.Second)
+	for {
+		r.mu.Lock()
+		call, ok := r.inflight[addr]
+		waiters := 0
+		if ok {
+			waiters = call.waiters
+		}
+		r.mu.Unlock()
+		if waiters >= n {
+			return
+		}
+		if time.Now().After(deadline) {
+			t.Fatalf("timed out waiting for %d in-flight waiter(s) on %s (have %d, registered=%v)", n, addr, waiters, ok)
+		}
+		time.Sleep(time.Millisecond)
+	}
+}
+
+func TestResolverCoalescesConcurrentResolvesForSameAddr(t *testing.T) {
+	t.Parallel()
+
+	// Block the elected prober inside the handler until both callers
+	// have registered. Single-flight should collapse the two Resolve
+	// calls into a single HTTP probe and deliver the same NodeInfo to
+	// each caller.
+	var calls int32
+	gate := make(chan struct{})
+	var gateOnce sync.Once
+	releaseGate := func() { gateOnce.Do(func() { close(gate) }) }
+	arrived := make(chan struct{}, 1)
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+		atomic.AddInt32(&calls, 1)
+		select {
+		case arrived <- struct{}{}:
+		default:
+		}
+		<-gate
+		w.Header().Set("Content-Type", "application/json")
+		_, _ = w.Write([]byte(`{
+			"version": 1,
+			"api_base_url": "https://api.example.com/api/v1",
+			"api_version": "v1"
+		}`))
+	}))
+	defer srv.Close()
+	// Deferred after srv.Close so it runs first (LIFO): if an assertion
+	// fails, the gate opens before Close waits on the blocked handler.
+	defer releaseGate()
+
+	r := newTestResolver(t)
+
+	type result struct {
+		info NodeInfo
+		err  error
+	}
+	out := make(chan result, 2)
+
+	go func() {
+		info, err := r.Resolve(context.Background(), srv.URL)
+		out <- result{info, err}
+	}()
+	<-arrived
+	go func() {
+		info, err := r.Resolve(context.Background(), srv.URL)
+		out <- result{info, err}
+	}()
+	r.waitForInflightWaiter(t, srv.URL, 1)
+	releaseGate()
+
+	r1 := <-out
+	r2 := <-out
+	require.NoError(t, r1.err)
+	require.NoError(t, r2.err)
+	require.Equal(t, r1.info, r2.info)
+	require.Equal(t, int32(1), atomic.LoadInt32(&calls))
+}
+
+func TestResolverProbesIndependentAddrsConcurrently(t *testing.T) {
+	t.Parallel()
+
+	// Two different addresses must not coalesce; each gets its own
+	// probe.
+	// httptest binds a single host, so we differentiate the two
+	// "addresses" by the addr argument's trailing path segment.
+	// The single-flight key is the full normalized addr, so this is
+	// enough to keep them distinct in-process.
+	var calls int32
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
+		atomic.AddInt32(&calls, 1)
+		w.Header().Set("Content-Type", "application/json")
+		if strings.HasPrefix(req.URL.Path, "/one") {
+			_, _ = w.Write([]byte(`{
+				"version": 1,
+				"api_base_url": "https://api-one.example.com/api/v1",
+				"api_version": "v1"
+			}`))
+			return
+		}
+		_, _ = w.Write([]byte(`{
+			"version": 1,
+			"api_base_url": "https://api-two.example.com/api/v1",
+			"api_version": "v1"
+		}`))
+	}))
+	defer srv.Close()
+
+	addrOne := srv.URL + "/one"
+	addrTwo := srv.URL + "/two"
+
+	r := newTestResolver(t)
+
+	type result struct {
+		info NodeInfo
+		err  error
+	}
+	out := make(chan result, 2)
+	go func() {
+		info, err := r.Resolve(context.Background(), addrOne)
+		out <- result{info, err}
+	}()
+	go func() {
+		info, err := r.Resolve(context.Background(), addrTwo)
+		out <- result{info, err}
+	}()
+
+	got := map[string]NodeInfo{}
+	for range 2 {
+		res := <-out
+		require.NoError(t, res.err)
+		got[res.info.APIBaseURL] = res.info
+	}
+	require.Equal(t, int32(2), atomic.LoadInt32(&calls))
+	require.Contains(t, got, "https://api-one.example.com/api/v1")
+	require.Contains(t, got, "https://api-two.example.com/api/v1")
+}
+
+func TestResolverPropagatesProbeFailureToAllWaiters(t *testing.T) {
+	t.Parallel()
+
+	// An elected probe that errors out must propagate the same error
+	// to every concurrent waiter, with exactly one HTTP call charged.
+	var calls int32
+	gate := make(chan struct{})
+	var gateOnce sync.Once
+	releaseGate := func() { gateOnce.Do(func() { close(gate) }) }
+	arrived := make(chan struct{}, 1)
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+		atomic.AddInt32(&calls, 1)
+		select {
+		case arrived <- struct{}{}:
+		default:
+		}
+		<-gate
+		w.Header().Set("Content-Type", "text/html")
+		_, _ = w.Write([]byte(`...`))
+	}))
+	defer srv.Close()
+	// Deferred after srv.Close so it runs first (LIFO): if an assertion
+	// fails, the gate opens before Close waits on the blocked handler.
+	defer releaseGate()
+
+	r := newTestResolver(t)
+	errs := make(chan error, 2)
+
+	go func() {
+		_, err := r.Resolve(context.Background(), srv.URL)
+		errs <- err
+	}()
+	<-arrived
+	go func() {
+		_, err := r.Resolve(context.Background(), srv.URL)
+		errs <- err
+	}()
+	r.waitForInflightWaiter(t, srv.URL, 1)
+	releaseGate()
+
+	e1 := <-errs
+	e2 := <-errs
+	require.Error(t, e1)
+	require.Error(t, e2)
+	require.Equal(t, int32(1), atomic.LoadInt32(&calls))
+}
+
+func TestResolverDoesNotMemoizeFailedProbes(t *testing.T) {
+	t.Parallel()
+
+	// A failed probe must not memoize; the next caller must retry the
+	// network. Two sequential failing calls means two HTTP probes.
+	var calls int32
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+		atomic.AddInt32(&calls, 1)
+		w.Header().Set("Content-Type", "text/html")
+		_, _ = w.Write([]byte(`...`))
+	}))
+	defer srv.Close()
+
+	r := newTestResolver(t)
+	_, err := r.Resolve(context.Background(), srv.URL)
+	require.Error(t, err)
+	_, err = r.Resolve(context.Background(), srv.URL)
+	require.Error(t, err)
+	require.Equal(t, int32(2), atomic.LoadInt32(&calls))
+}