diff --git a/src/Service.Tests/UnitTests/HealthCheckUtilitiesUnitTests.cs b/src/Service.Tests/UnitTests/HealthCheckUtilitiesUnitTests.cs new file mode 100644 index 0000000000..db9e2d9409 --- /dev/null +++ b/src/Service.Tests/UnitTests/HealthCheckUtilitiesUnitTests.cs @@ -0,0 +1,155 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#nullable enable + +using System; +using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Service.HealthCheck; +using Microsoft.Extensions.Logging; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Moq; + +namespace Azure.DataApiBuilder.Service.Tests.UnitTests +{ + /// + /// Unit tests for health check utility methods. + /// + [TestClass] + public class HealthCheckUtilitiesUnitTests + { + /// + /// Tests that connection strings are properly normalized for supported database types. + /// + [TestMethod] + [DataRow( + DatabaseType.PostgreSQL, + "Host=localhost;Port=5432;Database=testdb;Username=testuser;Password=XXXX", + "Host=localhost", + "Database=testdb", + DisplayName = "PostgreSQL connection string normalization")] + [DataRow( + DatabaseType.MSSQL, + "Server=localhost;Database=testdb;User Id=testuser;Password=XXXX", + "Data Source=localhost", + "Initial Catalog=testdb", + DisplayName = "MSSQL connection string normalization")] + [DataRow( + DatabaseType.DWSQL, + "Server=localhost;Database=testdb;User Id=testuser;Password=XXXX", + "Data Source=localhost", + "Initial Catalog=testdb", + DisplayName = "DWSQL connection string normalization")] + [DataRow( + DatabaseType.MySQL, + "Server=localhost;Port=3306;Database=testdb;Uid=testuser;Pwd=XXXX", + "Server=localhost", + "Database=testdb", + DisplayName = "MySQL connection string normalization")] + public void NormalizeConnectionString_SupportedDatabases_Success( + DatabaseType dbType, + string connectionString, + string expectedServerPart, + string expectedDatabasePart) + { + // Act + string result = HealthCheck.Utilities.NormalizeConnectionString(connectionString, dbType); + + // Assert + Assert.IsNotNull(result); + Assert.IsTrue(result.Contains(expectedServerPart)); + Assert.IsTrue(result.Contains(expectedDatabasePart)); + } + + /// + /// Tests that unsupported database types return the original connection string. + /// + [TestMethod] + public void NormalizeConnectionString_UnsupportedType_ReturnsOriginal() + { + // Arrange + string connectionString = "AccountEndpoint=https://test.documents.azure.com:443/;AccountKey=test"; + DatabaseType dbType = DatabaseType.CosmosDB_NoSQL; + + // Act + string result = HealthCheck.Utilities.NormalizeConnectionString(connectionString, dbType); + + // Assert + Assert.AreEqual(connectionString, result); + } + + /// + /// Tests that malformed connection strings are handled gracefully. + /// + [TestMethod] + [DataRow(DatabaseType.PostgreSQL, true, DisplayName = "PostgreSQL malformed string with logger")] + [DataRow(DatabaseType.MSSQL, true, DisplayName = "MSSQL malformed string with logger")] + [DataRow(DatabaseType.MySQL, false, DisplayName = "MySQL malformed string without logger")] + public void NormalizeConnectionString_MalformedString_ReturnsOriginal( + DatabaseType dbType, + bool useLogger) + { + // Arrange + string malformedConnectionString = "InvalidConnectionString;NoEquals"; + Mock? mockLogger = useLogger ? new Mock() : null; + + // Act + string result = HealthCheck.Utilities.NormalizeConnectionString( + malformedConnectionString, + dbType, + mockLogger?.Object); + + // Assert + Assert.AreEqual(malformedConnectionString, result); + if (useLogger && mockLogger != null) + { + mockLogger.Verify( + x => x.Log( + LogLevel.Warning, + It.IsAny(), + It.Is((v, t) => true), + It.IsAny(), + It.Is>((v, t) => true)), + Times.Once); + } + } + + /// + /// Tests that PostgreSQL connection strings with lowercase keywords are normalized correctly. + /// This is the specific bug that was reported - lowercase 'host' was not supported. + /// + [TestMethod] + public void NormalizeConnectionString_PostgreSQL_LowercaseKeywords_Success() + { + // Arrange + string connectionString = "host=localhost;port=5432;database=mydb;username=myuser;password=XXXX"; + DatabaseType dbType = DatabaseType.PostgreSQL; + + // Act + string result = HealthCheck.Utilities.NormalizeConnectionString(connectionString, dbType); + + // Assert + Assert.IsNotNull(result); + // NpgsqlConnectionStringBuilder should normalize lowercase keywords to proper format + Assert.IsTrue(result.Contains("Host=localhost") || result.Contains("host=localhost")); + Assert.IsTrue(result.Contains("Database=mydb") || result.Contains("database=mydb")); + } + + /// + /// Tests that empty connection strings are handled gracefully. + /// + [TestMethod] + public void NormalizeConnectionString_EmptyString_ReturnsEmpty() + { + // Arrange + string connectionString = string.Empty; + DatabaseType dbType = DatabaseType.PostgreSQL; + + // Act + string result = HealthCheck.Utilities.NormalizeConnectionString(connectionString, dbType); + + // Assert + Assert.AreEqual(string.Empty, result); + } + } +} diff --git a/src/Service/HealthCheck/HealthCheckHelper.cs b/src/Service/HealthCheck/HealthCheckHelper.cs index ab19756195..991e39983f 100644 --- a/src/Service/HealthCheck/HealthCheckHelper.cs +++ b/src/Service/HealthCheck/HealthCheckHelper.cs @@ -162,7 +162,7 @@ private async Task UpdateDataSourceHealthCheckResultsAsync(ComprehensiveHealthCh if (comprehensiveHealthCheckReport.Checks != null && runtimeConfig.DataSource.IsDatasourceHealthEnabled) { string query = Utilities.GetDatSourceQuery(runtimeConfig.DataSource.DatabaseType); - (int, string?) response = await ExecuteDatasourceQueryCheckAsync(query, runtimeConfig.DataSource.ConnectionString, Utilities.GetDbProviderFactory(runtimeConfig.DataSource.DatabaseType)); + (int, string?) response = await ExecuteDatasourceQueryCheckAsync(query, runtimeConfig.DataSource.ConnectionString, Utilities.GetDbProviderFactory(runtimeConfig.DataSource.DatabaseType), runtimeConfig.DataSource.DatabaseType); bool isResponseTimeWithinThreshold = response.Item1 >= 0 && response.Item1 < runtimeConfig.DataSource.DatasourceThresholdMs; // Add DataSource Health Check Results @@ -182,14 +182,14 @@ private async Task UpdateDataSourceHealthCheckResultsAsync(ComprehensiveHealthCh } // Executes the DB Query and keeps track of the response time and error message. - private async Task<(int, string?)> ExecuteDatasourceQueryCheckAsync(string query, string connectionString, DbProviderFactory dbProviderFactory) + private async Task<(int, string?)> ExecuteDatasourceQueryCheckAsync(string query, string connectionString, DbProviderFactory dbProviderFactory, DatabaseType databaseType) { string? errorMessage = null; if (!string.IsNullOrEmpty(query) && !string.IsNullOrEmpty(connectionString)) { Stopwatch stopwatch = new(); stopwatch.Start(); - errorMessage = await _httpUtility.ExecuteDbQueryAsync(query, connectionString, dbProviderFactory); + errorMessage = await _httpUtility.ExecuteDbQueryAsync(query, connectionString, dbProviderFactory, databaseType); stopwatch.Stop(); return string.IsNullOrEmpty(errorMessage) ? ((int)stopwatch.ElapsedMilliseconds, errorMessage) : (HealthCheckConstants.ERROR_RESPONSE_TIME_MS, errorMessage); } diff --git a/src/Service/HealthCheck/HttpUtilities.cs b/src/Service/HealthCheck/HttpUtilities.cs index 9da596ae30..2a8d7b9f3e 100644 --- a/src/Service/HealthCheck/HttpUtilities.cs +++ b/src/Service/HealthCheck/HttpUtilities.cs @@ -49,7 +49,7 @@ public HttpUtilities( } // Executes the DB query by establishing a connection to the DB. - public async Task ExecuteDbQueryAsync(string query, string connectionString, DbProviderFactory providerFactory) + public async Task ExecuteDbQueryAsync(string query, string connectionString, DbProviderFactory providerFactory, DatabaseType databaseType) { string? errorMessage = null; // Execute the query on DB and return the response time. @@ -65,7 +65,7 @@ public HttpUtilities( { try { - connection.ConnectionString = connectionString; + connection.ConnectionString = Utilities.NormalizeConnectionString(connectionString, databaseType, _logger); using (DbCommand command = connection.CreateCommand()) { command.CommandText = query; diff --git a/src/Service/HealthCheck/Utilities.cs b/src/Service/HealthCheck/Utilities.cs index 290410291e..888ffbca91 100644 --- a/src/Service/HealthCheck/Utilities.cs +++ b/src/Service/HealthCheck/Utilities.cs @@ -7,6 +7,8 @@ using System.Text.Json; using Azure.DataApiBuilder.Config.ObjectModel; using Microsoft.Data.SqlClient; +using Microsoft.Extensions.Logging; +using MySqlConnector; using Npgsql; namespace Azure.DataApiBuilder.Service.HealthCheck @@ -69,5 +71,32 @@ public static string CreateHttpRestQuery(string entityName, int first) // "EntityName?$first=4" return $"/{entityName}?$first={first}"; } + + public static string NormalizeConnectionString(string connectionString, DatabaseType dbType, ILogger? logger = null) + { + try + { + switch (dbType) + { + case DatabaseType.PostgreSQL: + return new NpgsqlConnectionStringBuilder(connectionString).ToString(); + case DatabaseType.MySQL: + return new MySqlConnectionStringBuilder(connectionString).ToString(); + case DatabaseType.MSSQL: + case DatabaseType.DWSQL: + return new SqlConnectionStringBuilder(connectionString).ToString(); + default: + return connectionString; + } + } + catch (Exception ex) + { + // Log the exception if a logger is provided + logger?.LogWarning(ex, "Failed to parse connection string for database type {DatabaseType}. Returning original connection string.", dbType); + // If the connection string cannot be parsed by the builder, + // return the original string to avoid failing the health check. + return connectionString; + } + } } }