diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/ModelProvider.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/ModelProvider.cs index de172f444a8..e916fb35c29 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/ModelProvider.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/ModelProvider.cs @@ -187,7 +187,53 @@ protected override string BuildNamespace() => string.IsNullOrEmpty(_inputModel.N protected override CSharpType? BuildBaseType() { - return BaseModelProvider?.Type; + if (CustomCodeView?.BaseType != null) + { + var customBase = CustomCodeView.BaseType; + + // If the custom base type doesn't have a resolved namespace, then try to resolve it from the input model map. + // This will happen if a model is customized to inherit from another generated model, but that generated model + // was not also defined in custom code so Roslyn does not recognize it. + if (string.IsNullOrEmpty(customBase.Namespace)) + { + if (CodeModelGenerator.Instance.TypeFactory.TypeProvidersByName.TryGetValue( + customBase.Name, out var resolvedProvider) && + resolvedProvider is ModelProvider resolvedModel) + { + return resolvedModel.Type; + } + + // Force-create all input models so that visitors run (which may rename models + // via TypeProvider.Update) and TypeProvidersByName is fully populated. + foreach (var model in CodeModelGenerator.Instance.InputLibrary.InputNamespace.Models) + { + CodeModelGenerator.Instance.TypeFactory.CreateModel(model); + } + + if (CodeModelGenerator.Instance.TypeFactory.TypeProvidersByName.TryGetValue( + customBase.Name, out resolvedProvider) && + resolvedProvider is ModelProvider resolvedAfterCreate) + { + return resolvedAfterCreate.Type; + } + } + + if (CodeModelGenerator.Instance.TypeFactory.CSharpTypeMap.TryGetValue( + customBase, out var mappedProvider) && + mappedProvider is ModelProvider mappedModel) + { + return mappedModel.Type; + } + + return customBase; + } + + if (_inputModel.BaseModel == null) + { + return null; + } + + return CodeModelGenerator.Instance.TypeFactory.CreateModel(_inputModel.BaseModel)?.Type; } protected override TypeProvider[] BuildSerializationProviders() @@ -293,63 +339,16 @@ private static bool IsDiscriminator(InputProperty property) private ModelProvider? BuildBaseModelProvider() { - // consider models that have been customized to inherit from a different generated model - if (CustomCodeView?.BaseType != null) - { - var baseType = CustomCodeView.BaseType; - - // If the custom base type doesn't have a resolved namespace, then try to resolve it from the input model map. - // This will happen if a model is customized to inherit from another generated model, but that generated model - // was not also defined in custom code so Roslyn does not recognize it. - if (string.IsNullOrEmpty(baseType.Namespace)) - { - // Cheap check: the base model may already be created and registered under the right name. - if (CodeModelGenerator.Instance.TypeFactory.TypeProvidersByName.TryGetValue( - baseType.Name, out var resolvedProvider) && - resolvedProvider is ModelProvider resolvedModel) - { - return resolvedModel; - } - - // Force-create all input models so that visitors run (which may rename models - // via TypeProvider.Update) and TypeProvidersByName is fully populated. - // This is a no-op for models that have already been created. - foreach (var model in CodeModelGenerator.Instance.InputLibrary.InputNamespace.Models) - { - CodeModelGenerator.Instance.TypeFactory.CreateModel(model); - } - - if (CodeModelGenerator.Instance.TypeFactory.TypeProvidersByName.TryGetValue( - baseType.Name, out resolvedProvider) && - resolvedProvider is ModelProvider resolvedAfterCreate) - { - return resolvedAfterCreate; - } - } - - // Try to find the base type in the CSharpTypeMap - if (baseType != null && CodeModelGenerator.Instance.TypeFactory.CSharpTypeMap.TryGetValue( - baseType, - out var customBaseType) && - customBaseType is ModelProvider customBaseModel) - { - return customBaseModel; - } - - // If the custom base type has a namespace (external type), we don't return it here - // as it's handled by BuildBaseTypeProvider() which returns a TypeProvider - if (!string.IsNullOrEmpty(baseType?.Namespace)) - { - return null; - } - } - - if (_inputModel.BaseModel == null) + var baseType = BaseType; + if (baseType is null) { return null; } - return CodeModelGenerator.Instance.TypeFactory.CreateModel(_inputModel.BaseModel); + return CodeModelGenerator.Instance.TypeFactory.CSharpTypeMap.TryGetValue(baseType, out var provider) + && provider is ModelProvider modelProvider + ? modelProvider + : null; } private List BuildAdditionalPropertyFields() diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelProviders/ModelProviderTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelProviders/ModelProviderTests.cs index 7c8e88e49d1..39f189fb0af 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelProviders/ModelProviderTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelProviders/ModelProviderTests.cs @@ -422,6 +422,107 @@ public void BuildBaseType() Assert.AreEqual(baseModel!.Type, derivedModel!.Type.BaseType); } + [Test] + public void OverridingBuildBaseType_AutoResolvesBaseModelProviderForGeneratedModel() + { + var inputBase = InputFactory.Model("baseModel", usage: InputModelTypeUsage.Input, properties: []); + var inputDerived = InputFactory.Model("derivedModel", usage: InputModelTypeUsage.Input, properties: []); + ModelProvider? baseProvider = null; + MockHelpers.LoadMockGenerator(createModelCore: input => + { + if (input == inputBase) + { + return baseProvider = new ModelProvider(input); + } + if (input == inputDerived) + { + return new BuildBaseTypeOverridingModelProvider(input, baseProvider!.Type); + } + return null; + }); + + var actualBase = CodeModelGenerator.Instance.TypeFactory.CreateModel(inputBase); + var actualDerived = CodeModelGenerator.Instance.TypeFactory.CreateModel(inputDerived); + + Assert.IsNotNull(actualBase); + Assert.IsNotNull(actualDerived); + Assert.AreEqual(actualBase!.Type, actualDerived!.BaseType); + Assert.AreSame(actualBase, actualDerived.BaseModelProvider); + } + + [Test] + public void OverridingBuildBaseType_AutoResolvesBaseModelProviderToNullForFrameworkType() + { + var inputDerived = InputFactory.Model("derivedModel", usage: InputModelTypeUsage.Input, properties: []); + var frameworkBase = new CSharpType(typeof(InvalidOperationException)); + MockHelpers.LoadMockGenerator(createModelCore: input => + input == inputDerived ? new BuildBaseTypeOverridingModelProvider(input, frameworkBase) : null); + + var actualDerived = CodeModelGenerator.Instance.TypeFactory.CreateModel(inputDerived); + + Assert.IsNotNull(actualDerived); + Assert.AreEqual(frameworkBase, actualDerived!.BaseType); + Assert.IsNull(actualDerived.BaseModelProvider); + } + + [Test] + public void BaseModelProvider_DefaultResolvesViaCSharpTypeMap() + { + var inputBase = InputFactory.Model("baseModel", usage: InputModelTypeUsage.Input, properties: []); + var inputDerived = InputFactory.Model("derivedModel", usage: InputModelTypeUsage.Input, properties: [], baseModel: inputBase); + + var derivedProvider = CodeModelGenerator.Instance.TypeFactory.CreateModel(inputDerived); + Assert.IsNotNull(derivedProvider); + Assert.IsNotNull(derivedProvider!.BaseModelProvider); + Assert.AreEqual(derivedProvider.BaseModelProvider!.Type, derivedProvider.BaseType); + } + + [Test] + public void BaseModelProvider_NullWhenNoBase() + { + var inputModel = InputFactory.Model("standaloneModel", usage: InputModelTypeUsage.Input, properties: []); + var modelProvider = CodeModelGenerator.Instance.TypeFactory.CreateModel(inputModel); + + Assert.IsNotNull(modelProvider); + Assert.IsNull(modelProvider!.BaseType); + Assert.IsNull(modelProvider.BaseModelProvider); + } + + [Test] + public void OverridingBuildBaseType_AutoResolvesBaseModelProviderToNullForNonModelTypeProvider() + { + var inputDerived = InputFactory.Model("derivedModel", usage: InputModelTypeUsage.Input, properties: []); + var nonModelTypeProvider = new NonModelTypeProvider(); + MockHelpers.LoadMockGenerator(createModelCore: input => + input == inputDerived ? new BuildBaseTypeOverridingModelProvider(input, nonModelTypeProvider.Type) : null); + CodeModelGenerator.Instance.TypeFactory.CSharpTypeMap[nonModelTypeProvider.Type] = nonModelTypeProvider; + + var actualDerived = CodeModelGenerator.Instance.TypeFactory.CreateModel(inputDerived); + + Assert.IsNotNull(actualDerived); + Assert.AreEqual(nonModelTypeProvider.Type, actualDerived!.BaseType); + Assert.IsNull(actualDerived.BaseModelProvider); + } + + private class NonModelTypeProvider : TypeProvider + { + protected override string BuildRelativeFilePath() => "."; + protected override string BuildName() => "NonModelBase"; + protected override string BuildNamespace() => "Custom.Namespace"; + } + + private class BuildBaseTypeOverridingModelProvider : ModelProvider + { + private readonly CSharpType? _redirectedBaseType; + + public BuildBaseTypeOverridingModelProvider(InputModelType inputModel, CSharpType? redirectedBaseType) : base(inputModel) + { + _redirectedBaseType = redirectedBaseType; + } + + protected override CSharpType? BuildBaseType() => _redirectedBaseType; + } + [Test] public void BuildModelAsStruct() {