Skip to content
137 changes: 137 additions & 0 deletions pkg/authserver/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/ory/fosite"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/bcrypt"

servercrypto "github.com/stacklok/toolhive/pkg/authserver/server/crypto"
"github.com/stacklok/toolhive/pkg/authserver/server/keys"
Expand All @@ -36,6 +37,9 @@ const (
testIssuer = "http://localhost"
testAudience = "https://mcp.example.com"

testConfidentialClientID = "test-confidential-client"
testConfidentialClientSecret = "test-confidential-secret"

// testAccessTokenLifetime is the configured access token lifetime in setupTestServer.
testAccessTokenLifetime = time.Hour
)
Expand All @@ -51,6 +55,7 @@ type testServerOptions struct {
upstream upstream.OAuth2Provider
scopes []string
accessTokenLifespan time.Duration
confidentialClient bool
}

// testServerOption is a functional option for test server setup.
Expand All @@ -77,6 +82,13 @@ func withAccessTokenLifespan(d time.Duration) testServerOption {
}
}

// withConfidentialClient registers a confidential client for client_credentials testing.
func withConfidentialClient() testServerOption {
return func(opts *testServerOptions) {
opts.confidentialClient = true
}
}

// testKeyProvider is a simple KeyProvider for tests that uses a pre-generated RSA key.
type testKeyProvider struct {
key *rsa.PrivateKey
Expand Down Expand Up @@ -138,6 +150,22 @@ func setupTestServer(t *testing.T, opts ...testServerOption) *testServer {
})
require.NoError(t, err)

// Optionally register a confidential client for client_credentials testing
if options.confidentialClient {
hashedSecret, err := bcrypt.GenerateFromPassword([]byte(testConfidentialClientSecret), bcrypt.DefaultCost)
require.NoError(t, err)
err = stor.RegisterClient(ctx, &fosite.DefaultClient{
ID: testConfidentialClientID,
Secret: hashedSecret,
RedirectURIs: nil,
ResponseTypes: nil,
GrantTypes: []string{"client_credentials"},
Scopes: []string{"openid"},
Public: false,
})
require.NoError(t, err)
}

// 5. Build upstream config for newServer
// When no upstream is provided, use a dummy config that satisfies validation
// Note: Uses HTTPS to pass config validation
Expand Down Expand Up @@ -1162,3 +1190,112 @@ func TestIntegration_RefreshToken_ShortLivedAccessToken(t *testing.T) {
require.True(t, ok)
assert.Greater(t, int64(exp), time.Now().Unix(), "refreshed token exp must be in the future")
}

// ============================================================================
// Client Credentials Flow Integration Tests
// ============================================================================

// TestIntegration_ClientCredentials_BasicFlow tests the client_credentials grant flow.
func TestIntegration_ClientCredentials_BasicFlow(t *testing.T) {
t.Parallel()

ts := setupTestServer(t, withConfidentialClient())

// Make client_credentials token request with client_secret_post
params := url.Values{
"grant_type": {"client_credentials"},
"client_id": {testConfidentialClientID},
"client_secret": {testConfidentialClientSecret},
"scope": {"openid"},
}

resp := makeTokenRequest(t, ts.Server.URL, params)
defer resp.Body.Close()

require.Equal(t, http.StatusOK, resp.StatusCode, "client_credentials request should succeed")
tokenData := parseTokenResponse(t, resp)

// Verify access token is present
accessToken, ok := tokenData["access_token"].(string)
require.True(t, ok, "access_token should be a string")
require.NotEmpty(t, accessToken)

// Verify token_type
tokenType, ok := tokenData["token_type"].(string)
require.True(t, ok)
assert.Equal(t, "bearer", strings.ToLower(tokenType))

// Verify no refresh token (client_credentials should not issue refresh tokens)
_, hasRefresh := tokenData["refresh_token"]
assert.False(t, hasRefresh, "client_credentials should not issue refresh_token")

// Verify JWT claims
parsedToken, err := jwt.ParseSigned(accessToken, []jose.SignatureAlgorithm{jose.RS256})
require.NoError(t, err)

var claims map[string]interface{}
err = parsedToken.Claims(ts.PrivateKey.Public(), &claims)
require.NoError(t, err)

// Subject should be the client ID for M2M tokens
assert.Equal(t, testConfidentialClientID, claims["sub"], "sub should be client ID for M2M tokens")
assert.Equal(t, testConfidentialClientID, claims["client_id"], "client_id claim should match")
assert.Equal(t, testIssuer, claims["iss"], "issuer should match")
}

// TestIntegration_ClientCredentials_WithAudience tests client_credentials with RFC 8707 resource parameter.
func TestIntegration_ClientCredentials_WithAudience(t *testing.T) {
t.Parallel()

ts := setupTestServer(t, withConfidentialClient())

params := url.Values{
"grant_type": {"client_credentials"},
"client_id": {testConfidentialClientID},
"client_secret": {testConfidentialClientSecret},
"scope": {"openid"},
"resource": {testAudience},
}

resp := makeTokenRequest(t, ts.Server.URL, params)
defer resp.Body.Close()

require.Equal(t, http.StatusOK, resp.StatusCode)
tokenData := parseTokenResponse(t, resp)

accessToken, ok := tokenData["access_token"].(string)
require.True(t, ok)

parsedToken, err := jwt.ParseSigned(accessToken, []jose.SignatureAlgorithm{jose.RS256})
require.NoError(t, err)

var claims map[string]interface{}
err = parsedToken.Claims(ts.PrivateKey.Public(), &claims)
require.NoError(t, err)

// Verify audience from resource parameter
aud, ok := claims["aud"].([]interface{})
require.True(t, ok, "aud should be an array")
require.Len(t, aud, 1)
assert.Equal(t, testAudience, aud[0], "audience should match requested resource")
}

// TestIntegration_ClientCredentials_WrongSecret tests that wrong secrets are rejected.
func TestIntegration_ClientCredentials_WrongSecret(t *testing.T) {
t.Parallel()

ts := setupTestServer(t, withConfidentialClient())

params := url.Values{
"grant_type": {"client_credentials"},
"client_id": {testConfidentialClientID},
"client_secret": {"wrong-secret"},
}

resp := makeTokenRequest(t, ts.Server.URL, params)
defer resp.Body.Close()

require.Equal(t, http.StatusUnauthorized, resp.StatusCode)
errResp := parseTokenResponse(t, resp)
assert.Equal(t, "invalid_client", errResp["error"])
}
13 changes: 11 additions & 2 deletions pkg/authserver/server/handlers/dcr.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,22 @@ func (h *Handler) RegisterClientHandler(w http.ResponseWriter, req *http.Request
return
}

// Generate client ID
// Determine if this is a confidential client registration.
isConfidential := validated.TokenEndpointAuthMethod != "none"

// Generate client credentials.
clientID := uuid.NewString()
var clientSecret string
if isConfidential {
clientSecret = uuid.NewString() // Secure random secret
}

// Create fosite client using factory.
fositeClient, err := registration.New(registration.Config{
ID: clientID,
Secret: clientSecret,
RedirectURIs: validated.RedirectURIs,
Public: true,
Public: !isConfidential,
GrantTypes: validated.GrantTypes,
ResponseTypes: validated.ResponseTypes,
Scopes: scopes,
Expand Down Expand Up @@ -106,6 +114,7 @@ func (h *Handler) RegisterClientHandler(w http.ResponseWriter, req *http.Request
// the client know exactly which scopes it can request.
response := registration.DCRResponse{
ClientID: clientID,
ClientSecret: clientSecret, // Only set for confidential clients
ClientIDIssuedAt: time.Now().Unix(),
RedirectURIs: validated.RedirectURIs,
ClientName: validated.ClientName,
Expand Down
9 changes: 7 additions & 2 deletions pkg/authserver/server/handlers/discovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,15 @@ func (h *Handler) buildOAuthMetadata() sharedobauth.AuthorizationServerMetadata
// OPTIONAL
GrantTypesSupported: []string{
string(fosite.GrantTypeAuthorizationCode),
string(fosite.GrantTypeClientCredentials),
string(fosite.GrantTypeRefreshToken),
},
CodeChallengeMethodsSupported: []string{crypto.PKCEChallengeMethodS256},
TokenEndpointAuthMethodsSupported: []string{sharedobauth.TokenEndpointAuthMethodNone},
CodeChallengeMethodsSupported: []string{crypto.PKCEChallengeMethodS256},
TokenEndpointAuthMethodsSupported: []string{
sharedobauth.TokenEndpointAuthMethodNone,
"client_secret_basic",
"client_secret_post",
},
}
}

Expand Down
6 changes: 6 additions & 0 deletions pkg/authserver/server/handlers/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,11 @@ func TestOAuthDiscoveryHandler(t *testing.T) {
// Verify OPTIONAL fields per RFC 8414
assert.Contains(t, metadata.GrantTypesSupported, "authorization_code")
assert.Contains(t, metadata.GrantTypesSupported, "refresh_token")
assert.Contains(t, metadata.GrantTypesSupported, "client_credentials")
assert.Contains(t, metadata.CodeChallengeMethodsSupported, "S256")
assert.Contains(t, metadata.TokenEndpointAuthMethodsSupported, "none")
assert.Contains(t, metadata.TokenEndpointAuthMethodsSupported, "client_secret_basic")
assert.Contains(t, metadata.TokenEndpointAuthMethodsSupported, "client_secret_post")
}

func TestOAuthDiscoveryHandler_DoesNotContainOIDCFields(t *testing.T) {
Expand Down Expand Up @@ -228,8 +231,11 @@ func TestOIDCDiscoveryHandler(t *testing.T) {
// Verify OPTIONAL fields
assert.Contains(t, discovery.GrantTypesSupported, "authorization_code")
assert.Contains(t, discovery.GrantTypesSupported, "refresh_token")
assert.Contains(t, discovery.GrantTypesSupported, "client_credentials")
assert.Contains(t, discovery.CodeChallengeMethodsSupported, "S256")
assert.Contains(t, discovery.TokenEndpointAuthMethodsSupported, "none")
assert.Contains(t, discovery.TokenEndpointAuthMethodsSupported, "client_secret_basic")
assert.Contains(t, discovery.TokenEndpointAuthMethodsSupported, "client_secret_post")
}

// TODO: Add tests for TokenHandler once implemented:
Expand Down
27 changes: 24 additions & 3 deletions pkg/authserver/server/handlers/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/ory/fosite/compose"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"golang.org/x/crypto/bcrypt"

"github.com/stacklok/toolhive/pkg/authserver/server"
servercrypto "github.com/stacklok/toolhive/pkg/authserver/server/crypto"
Expand All @@ -29,6 +30,11 @@ const (
testInternalState = "internal-state-123"
)

const (
testConfidentialClientID = "test-confidential-client"
testConfidentialClientSecret = "test-secret-12345"
)

// mockIDPProvider implements upstream.OAuth2Provider for testing.
type mockIDPProvider struct {
providerType upstream.ProviderType
Expand Down Expand Up @@ -150,14 +156,28 @@ func handlerTestSetup(t *testing.T) (*Handler, *testStorageState, *mockIDPProvid
}
storState.clients[testAuthClientID] = testClient

// Setup mock expectations for GetClient
stor.EXPECT().GetClient(gomock.Any(), testAuthClientID).DoAndReturn(func(_ context.Context, id string) (fosite.Client, error) {
// Register a confidential test client (for client_credentials)
hashedSecret, err := bcrypt.GenerateFromPassword([]byte(testConfidentialClientSecret), bcrypt.DefaultCost)
require.NoError(t, err)

confidentialClient := &fosite.DefaultClient{
ID: testConfidentialClientID,
Secret: hashedSecret,
RedirectURIs: nil,
ResponseTypes: nil,
GrantTypes: []string{"client_credentials"},
Scopes: []string{"openid"},
Public: false,
}
storState.clients[testConfidentialClientID] = confidentialClient

// Setup mock expectations for GetClient — return any registered client
stor.EXPECT().GetClient(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, id string) (fosite.Client, error) {
if c, ok := storState.clients[id]; ok {
return c, nil
}
return nil, fosite.ErrNotFound
}).AnyTimes()
stor.EXPECT().GetClient(gomock.Any(), gomock.Not(testAuthClientID)).Return(nil, fosite.ErrNotFound).AnyTimes()

// Setup mock expectations for pending authorization storage
stor.EXPECT().StorePendingAuthorization(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
Expand Down Expand Up @@ -317,6 +337,7 @@ func handlerTestSetup(t *testing.T) (*Handler, *testStorageState, *mockIDPProvid
stor,
&compose.CommonStrategy{CoreStrategy: jwtStrategy},
compose.OAuth2AuthorizeExplicitFactory,
compose.OAuth2ClientCredentialsGrantFactory,
compose.OAuth2RefreshTokenGrantFactory,
compose.OAuth2PKCEFactory,
)
Expand Down
8 changes: 8 additions & 0 deletions pkg/authserver/server/handlers/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@ func (h *Handler) TokenHandler(w http.ResponseWriter, req *http.Request) {
return
}

// For client_credentials grant, fosite uses the placeholder session directly
// (there's no stored authorize session to retrieve). We must populate the
// session's subject with the client ID so the JWT has a meaningful "sub" claim.
if accessRequest.GetGrantTypes().ExactOne("client_credentials") {
clientID := accessRequest.GetClient().GetID()
accessRequest.SetSession(session.New(clientID, "", clientID))
}

// RFC 8707: Handle resource parameter for audience claim.
// The resource parameter allows clients to specify which protected resource (MCP server)
// the token is intended for. This value becomes the "aud" claim in the JWT.
Expand Down
Loading