Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
using Microsoft.TypeSpec.Generator.Providers;
using Microsoft.TypeSpec.Generator.Snippets;
using Microsoft.TypeSpec.Generator.Statements;
using Microsoft.TypeSpec.Generator.Shared;
using Microsoft.TypeSpec.Generator.Utilities;
using static Microsoft.TypeSpec.Generator.Snippets.Snippet;

Expand Down Expand Up @@ -123,7 +124,7 @@ private static bool UseSingletonInstance(InputClient inputClient)

internal IReadOnlyDictionary<EnumProvider, PropertyProvider>? VersionProperties => field ??= BuildVersionProperties();

private Dictionary<EnumProvider, PropertyProvider>? BuildVersionProperties()
private Dictionary<EnumProvider, PropertyProvider>? BuildVersionProperties()
{
if (_serviceVersionsEnums is null)
{
Expand All @@ -133,11 +134,26 @@ private static bool UseSingletonInstance(InputClient inputClient)
var properties = new Dictionary<EnumProvider, PropertyProvider>(_serviceVersionsEnums.Count);
foreach (var (inputEnum, enumProvider) in _serviceVersionsEnums)
{
// For multi-service clients, use the full namespace to guarantee uniqueness
// (the last segment alone can collide when services share a namespace).
var versionPropertyName = _inputClient.IsMultiServiceClient
? $"{inputEnum.Namespace.ToIdentifierName()}{ApiVersionSuffix}"
: VersionSuffix;
string versionPropertyName;
if (!_inputClient.IsMultiServiceClient)
{
versionPropertyName = VersionSuffix;
}
else
{
var serviceNamespace = inputEnum.Namespace;
if (!string.IsNullOrEmpty(serviceNamespace) &&
ClientHelper.HasLastSegmentCollision(serviceNamespace, inputEnum, _serviceVersionsEnums.Keys))
{
// Last segment collides — find the shortest unique namespace suffix.
string uniquePrefix = ClientHelper.GetShortestUniqueNamespacePrefix(serviceNamespace, inputEnum, _serviceVersionsEnums.Keys);
versionPropertyName = $"{uniquePrefix.ToIdentifierName()}{ApiVersionSuffix}";
}
else
{
versionPropertyName = ClientHelper.BuildNameForService(serviceNamespace ?? string.Empty, string.Empty, ApiVersionSuffix);
}
}

var versionProperty = new PropertyProvider(
null,
Expand All @@ -161,11 +177,28 @@ private static bool UseSingletonInstance(InputClient inputClient)
}

Dictionary<FieldProvider, EnumProvider> latestVersionFields = new(_serviceVersionsEnums.Count);
foreach (var enumProvider in _serviceVersionsEnums.Values)
foreach (var (inputEnum, enumProvider) in _serviceVersionsEnums)
{
var fieldName = _inputClient.IsMultiServiceClient
? $"{LatestPrefix}{enumProvider.Name.ToIdentifierName()}"
: LatestVersionFieldName;
string fieldName;
if (!_inputClient.IsMultiServiceClient)
{
fieldName = LatestVersionFieldName;
}
else
{
var serviceNamespace = inputEnum.Namespace;
if (!string.IsNullOrEmpty(serviceNamespace) &&
ClientHelper.HasLastSegmentCollision(serviceNamespace, inputEnum, _serviceVersionsEnums.Keys))
{
// Last segment collides — find the shortest unique namespace suffix.
string uniquePrefix = ClientHelper.GetShortestUniqueNamespacePrefix(serviceNamespace, inputEnum, _serviceVersionsEnums.Keys);
fieldName = $"{LatestPrefix}{uniquePrefix.ToIdentifierName()}{VersionSuffix}";
}
else
{
fieldName = ClientHelper.BuildNameForService(serviceNamespace ?? string.Empty, LatestPrefix, VersionSuffix);
}
}
var field = new FieldProvider(
modifiers: FieldModifiers.Private | FieldModifiers.Const,
type: enumProvider.Type,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,163 @@ public void MultiServiceClient_SameLastSegment_ProducesUniqueVersionEnums()
var nestedTypes = clientOptionsProvider!.NestedTypes;
Assert.AreEqual(2, nestedTypes.Count);
CollectionAssert.AllItemsAreUnique(nestedTypes.Select(t => t.Name).ToList());

var writer = new TypeProviderWriter(clientOptionsProvider!);
var file = writer.Write();

Assert.AreEqual(Helpers.GetExpectedFromFile(), file.Content);
}

[Test]
public void MultiServiceClient_UniqueNamespaces_ProducesUniqueVersionEnums()
{
List<string> serviceOneVersions = ["2024-01-01"];
List<string> serviceTwoVersions = ["2024-06-01"];

var serviceOneEnumValues = serviceOneVersions.Select(a => (a, a));
var serviceTwoEnumValues = serviceTwoVersions.Select(a => (a, a));

var serviceOneEnum = InputFactory.StringEnum(
"Versions",
serviceOneEnumValues,
usage: InputModelTypeUsage.ApiVersionEnum,
clientNamespace: "ServiceOne");
var serviceTwoEnum = InputFactory.StringEnum(
"Versions",
serviceTwoEnumValues,
usage: InputModelTypeUsage.ApiVersionEnum,
clientNamespace: "ServiceTwo");

InputParameter apiVersionParameter = InputFactory.QueryParameter(
"apiVersion",
InputPrimitiveType.String,
isRequired: true,
scope: InputParameterScope.Client,
isApiVersion: true);

var serviceOneOperation = InputFactory.Operation(
"ServiceOneOperation",
parameters: [apiVersionParameter],
ns: "ServiceOne");

var serviceTwoOperation = InputFactory.Operation(
"ServiceTwoOperation",
parameters: [apiVersionParameter],
ns: "ServiceTwo");

var client = InputFactory.Client(
"MultiServiceClient",
methods:
[
InputFactory.BasicServiceMethod("ServiceOneMethod", serviceOneOperation),
InputFactory.BasicServiceMethod("ServiceTwoMethod", serviceTwoOperation)
],
parameters: [apiVersionParameter],
isMultiServiceClient: true);

MockHelpers.LoadMockGenerator(
apiVersions: () => [.. serviceOneVersions, .. serviceTwoVersions],
clients: () => [client],
inputEnums: () => [serviceOneEnum, serviceTwoEnum]);

var clientProvider = ScmCodeModelGenerator.Instance.TypeFactory.CreateClient(client);
Assert.IsNotNull(clientProvider);

// Validate that Fields access does not crash (the original issue crashed here)
Assert.DoesNotThrow(() => _ = clientProvider!.Fields);

// Validate that Methods access does not crash (original crash site: Fields.ToDictionary in BuildMethods)
Assert.DoesNotThrow(() => _ = clientProvider!.Methods);

var clientOptionsProvider = clientProvider?.ClientOptions;
Assert.IsNotNull(clientOptionsProvider);

// Validate nested service version enums have unique names
var nestedTypes = clientOptionsProvider!.NestedTypes;
Assert.AreEqual(2, nestedTypes.Count);
CollectionAssert.AllItemsAreUnique(nestedTypes.Select(t => t.Name).ToList());

// Verify enum names follow the XServiceVersion pattern
Assert.AreEqual("ServiceOneServiceVersion", nestedTypes[0].Name);
Assert.AreEqual("ServiceTwoServiceVersion", nestedTypes[1].Name);

var writer = new TypeProviderWriter(clientOptionsProvider!);
var file = writer.Write();
Assert.AreEqual(Helpers.GetExpectedFromFile(), file.Content);
}

[Test]
public void MultiServiceClient_SameNamespace_ProducesUniqueVersionEnums()
{
// Regression test for the scenario where both enums share the exact same namespace
// (e.g., when tspconfig remaps both services to the same C# output namespace).
List<string> serviceOneVersions = ["2024-01-01"];
List<string> serviceTwoVersions = ["2024-06-01"];

var serviceOneEnumValues = serviceOneVersions.Select(a => (a, a));
var serviceTwoEnumValues = serviceTwoVersions.Select(a => (a, a));

// Both enums have the EXACT SAME namespace (simulates tspconfig namespace override)
var serviceOneEnum = InputFactory.StringEnum(
"ServiceOneVersions",
serviceOneEnumValues,
usage: InputModelTypeUsage.ApiVersionEnum,
clientNamespace: "Azure.Generator.MgmtTypeSpec.MultiService.Tests");
var serviceTwoEnum = InputFactory.StringEnum(
"ServiceTwoVersions",
serviceTwoEnumValues,
usage: InputModelTypeUsage.ApiVersionEnum,
clientNamespace: "Azure.Generator.MgmtTypeSpec.MultiService.Tests");

InputParameter apiVersionParameter = InputFactory.QueryParameter(
"apiVersion",
InputPrimitiveType.String,
isRequired: true,
scope: InputParameterScope.Client,
isApiVersion: true);

var serviceOneOperation = InputFactory.Operation(
"ServiceOneOperation",
parameters: [apiVersionParameter],
ns: "Azure.Generator.MgmtTypeSpec.MultiService.Tests");

var serviceTwoOperation = InputFactory.Operation(
"ServiceTwoOperation",
parameters: [apiVersionParameter],
ns: "Azure.Generator.MgmtTypeSpec.MultiService.Tests");

var client = InputFactory.Client(
"MultiServiceClient",
methods:
[
InputFactory.BasicServiceMethod("ServiceOneMethod", serviceOneOperation),
InputFactory.BasicServiceMethod("ServiceTwoMethod", serviceTwoOperation)
],
parameters: [apiVersionParameter],
isMultiServiceClient: true);

MockHelpers.LoadMockGenerator(
apiVersions: () => [.. serviceOneVersions, .. serviceTwoVersions],
clients: () => [client],
inputEnums: () => [serviceOneEnum, serviceTwoEnum]);

var clientProvider = ScmCodeModelGenerator.Instance.TypeFactory.CreateClient(client);
Assert.IsNotNull(clientProvider);

Assert.DoesNotThrow(() => _ = clientProvider!.Fields);
Assert.DoesNotThrow(() => _ = clientProvider!.Methods);

var clientOptionsProvider = clientProvider?.ClientOptions;
Assert.IsNotNull(clientOptionsProvider);

// Validate nested service version enums have unique names
var nestedTypes = clientOptionsProvider!.NestedTypes;
Assert.AreEqual(2, nestedTypes.Count);
CollectionAssert.AllItemsAreUnique(nestedTypes.Select(t => t.Name).ToList());

var writer = new TypeProviderWriter(clientOptionsProvider!);
var file = writer.Write();
Assert.AreEqual(Helpers.GetExpectedFromFile(), file.Content);
}

[Test]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3678,12 +3678,12 @@ public void GetApiVersionFieldForService_MultiService_ReturnsMatchingField()
// Should return the matching field for ServiceA
var fieldA = clientProvider!.GetApiVersionFieldForService("Sample.ServiceA");
Assert.IsNotNull(fieldA);
Assert.AreEqual("_sampleServiceAApiVersion", fieldA!.Name);
Assert.AreEqual("_serviceAApiVersion", fieldA!.Name);

// Should return the matching field for ServiceB
var fieldB = clientProvider.GetApiVersionFieldForService("Sample.ServiceB");
Assert.IsNotNull(fieldB);
Assert.AreEqual("_sampleServiceBApiVersion", fieldB!.Name);
Assert.AreEqual("_serviceBApiVersion", fieldB!.Name);
}

[Test]
Expand Down Expand Up @@ -3814,11 +3814,11 @@ public void GetApiVersionFieldForService_MultiService_CaseInsensitiveMatch()
// Should match case-insensitively
var fieldLowerCase = clientProvider!.GetApiVersionFieldForService("sample.serviceA");
Assert.IsNotNull(fieldLowerCase);
Assert.AreEqual("_sampleServiceAApiVersion", fieldLowerCase!.Name);
Assert.AreEqual("_serviceAApiVersion", fieldLowerCase!.Name);

var fieldUpperCase = clientProvider.GetApiVersionFieldForService("SAMPLE.SERVICEa");
Assert.IsNotNull(fieldUpperCase);
Assert.AreEqual("_sampleServiceAApiVersion", fieldUpperCase!.Name);
Assert.AreEqual("_serviceAApiVersion", fieldUpperCase!.Name);
}

[Test]
Expand Down Expand Up @@ -3883,17 +3883,17 @@ public void GetApiVersionFieldForService_MultiService_SameLastSegment_ProducesUn
// This should not crash — previously it threw due to duplicate field names
Assert.DoesNotThrow(() => _ = clientProvider!.Fields);

// Verify we have two distinct api version fields using the full namespace
// Verify we have two distinct api version fields using the shortest unique namespace suffix
var apiVersionFields = clientProvider!.Fields
.Where(f => f.Name.Contains("ApiVersion", StringComparison.OrdinalIgnoreCase))
.OrderBy(f => f.Name)
.ToList();
Assert.AreEqual(2, apiVersionFields.Count);
Assert.AreNotEqual(apiVersionFields[0].Name, apiVersionFields[1].Name);

// Full namespace produces unique names: "Azure.ServiceOne.Tests" → "AzureServiceOneTests"
Assert.AreEqual("_azureServiceOneTestsApiVersion", apiVersionFields[0].Name);
Assert.AreEqual("_azureServiceTwoTestsApiVersion", apiVersionFields[1].Name);
// Shortest unique suffix: "ServiceOne.Tests" → "ServiceOneTests"
Assert.AreEqual("_serviceOneTestsApiVersion", apiVersionFields[0].Name);
Assert.AreEqual("_serviceTwoTestsApiVersion", apiVersionFields[1].Name);
}

[TestCase("{endpoint}")]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ public partial class TestClient
{
private readonly global::System.Uri _endpoint;
private readonly string _subscriptionId;
private readonly string _sampleServiceAApiVersion;
private readonly string _sampleServiceBApiVersion;
private readonly string _serviceAApiVersion;
private readonly string _serviceBApiVersion;
private global::Sample.ServiceA.ServiceA _cachedServiceA;
private global::Sample.ServiceB.ServiceB _cachedServiceB;

Expand Down Expand Up @@ -44,8 +44,8 @@ internal TestClient(global::System.ClientModel.Primitives.AuthenticationPolicy a
{
Pipeline = global::System.ClientModel.Primitives.ClientPipeline.Create(options, Array.Empty<global::System.ClientModel.Primitives.PipelinePolicy>(), new global::System.ClientModel.Primitives.PipelinePolicy[] { new global::System.ClientModel.Primitives.UserAgentPolicy(typeof(global::Sample.TestClient).Assembly) }, Array.Empty<global::System.ClientModel.Primitives.PipelinePolicy>());
}
_sampleServiceAApiVersion = options.SampleServiceAApiVersion;
_sampleServiceBApiVersion = options.SampleServiceBApiVersion;
_serviceAApiVersion = options.ServiceAApiVersion;
_serviceBApiVersion = options.ServiceBApiVersion;
}

public TestClient(global::System.Uri endpoint, string subscriptionId, global::Sample.TestClientOptions options) : this(null, endpoint, subscriptionId, options)
Expand All @@ -56,12 +56,12 @@ public TestClient(global::System.Uri endpoint, string subscriptionId, global::Sa

public virtual global::Sample.ServiceA.ServiceA GetServiceAClient()
{
return (global::System.Threading.Volatile.Read(ref _cachedServiceA) ?? (global::System.Threading.Interlocked.CompareExchange(ref _cachedServiceA, new global::Sample.ServiceA.ServiceA(Pipeline, _endpoint, _sampleServiceAApiVersion, _subscriptionId), null) ?? _cachedServiceA));
return (global::System.Threading.Volatile.Read(ref _cachedServiceA) ?? (global::System.Threading.Interlocked.CompareExchange(ref _cachedServiceA, new global::Sample.ServiceA.ServiceA(Pipeline, _endpoint, _serviceAApiVersion, _subscriptionId), null) ?? _cachedServiceA));
}

public virtual global::Sample.ServiceB.ServiceB GetServiceBClient()
{
return (global::System.Threading.Volatile.Read(ref _cachedServiceB) ?? (global::System.Threading.Interlocked.CompareExchange(ref _cachedServiceB, new global::Sample.ServiceB.ServiceB(Pipeline, _endpoint, _sampleServiceBApiVersion, _subscriptionId), null) ?? _cachedServiceB));
return (global::System.Threading.Volatile.Read(ref _cachedServiceB) ?? (global::System.Threading.Interlocked.CompareExchange(ref _cachedServiceB, new global::Sample.ServiceB.ServiceB(Pipeline, _endpoint, _serviceBApiVersion, _subscriptionId), null) ?? _cachedServiceB));
}
}
}
Loading
Loading