Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 195 additions & 0 deletions src/Service.Tests/UnitTests/HealthCheckUtilitiesUnitTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

#nullable enable

using System;
using Azure.DataApiBuilder.Config.ObjectModel;
using Azure.DataApiBuilder.Service.HealthCheck;
using Microsoft.Extensions.Logging;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Moq;

namespace Azure.DataApiBuilder.Service.Tests.UnitTests
{
/// <summary>
/// Unit tests for health check utility methods.
/// </summary>
[TestClass]
public class HealthCheckUtilitiesUnitTests
{
/// <summary>
/// Tests that PostgreSQL connection strings are properly normalized.
/// </summary>
[TestMethod]
public void NormalizeConnectionString_PostgreSQL_Success()
{
// Arrange
string connectionString = "Host=localhost;Port=5432;Database=testdb;Username=testuser;Password=testpass";
DatabaseType dbType = DatabaseType.PostgreSQL;

// Act
string result = HealthCheck.Utilities.NormalizeConnectionString(connectionString, dbType);

// Assert
Assert.IsNotNull(result);
Assert.IsTrue(result.Contains("Host=localhost"));
Assert.IsTrue(result.Contains("Database=testdb"));
}

/// <summary>
/// Tests that MSSQL connection strings are properly normalized.
/// </summary>
[TestMethod]
public void NormalizeConnectionString_MSSQL_Success()
{
// Arrange
string connectionString = "Server=localhost;Database=testdb;User Id=testuser;Password=testpass";
DatabaseType dbType = DatabaseType.MSSQL;

// Act
string result = HealthCheck.Utilities.NormalizeConnectionString(connectionString, dbType);

// Assert
Assert.IsNotNull(result);
Assert.IsTrue(result.Contains("Data Source=localhost"));
Assert.IsTrue(result.Contains("Initial Catalog=testdb"));
}

/// <summary>
/// Tests that DWSQL connection strings are properly normalized.
/// </summary>
[TestMethod]
public void NormalizeConnectionString_DWSQL_Success()
{
// Arrange
string connectionString = "Server=localhost;Database=testdb;User Id=testuser;Password=testpass";
DatabaseType dbType = DatabaseType.DWSQL;

// Act
string result = HealthCheck.Utilities.NormalizeConnectionString(connectionString, dbType);

// Assert
Assert.IsNotNull(result);
Assert.IsTrue(result.Contains("Data Source=localhost"));
Assert.IsTrue(result.Contains("Initial Catalog=testdb"));
}

/// <summary>
/// Tests that MySQL connection strings are properly normalized.
/// </summary>
[TestMethod]
public void NormalizeConnectionString_MySQL_Success()
{
// Arrange
string connectionString = "Server=localhost;Port=3306;Database=testdb;Uid=testuser;Pwd=testpass";
DatabaseType dbType = DatabaseType.MySQL;

// Act
string result = HealthCheck.Utilities.NormalizeConnectionString(connectionString, dbType);

// Assert
Assert.IsNotNull(result);
Assert.IsTrue(result.Contains("Server=localhost"));
Assert.IsTrue(result.Contains("Database=testdb"));
}

/// <summary>
/// Tests that unsupported database types return the original connection string.
/// </summary>
[TestMethod]
public void NormalizeConnectionString_UnsupportedType_ReturnsOriginal()
{
// Arrange
string connectionString = "AccountEndpoint=https://test.documents.azure.com:443/;AccountKey=test";
DatabaseType dbType = DatabaseType.CosmosDB_NoSQL;

// Act
string result = HealthCheck.Utilities.NormalizeConnectionString(connectionString, dbType);

// Assert
Assert.AreEqual(connectionString, result);
}

/// <summary>
/// Tests that malformed connection strings are handled gracefully.
/// </summary>
[TestMethod]
public void NormalizeConnectionString_MalformedString_ReturnsOriginalAndLogs()
{
// Arrange
string malformedConnectionString = "InvalidConnectionString;NoEquals";
DatabaseType dbType = DatabaseType.PostgreSQL;
Mock<ILogger> mockLogger = new Mock<ILogger>();

// Act
string result = HealthCheck.Utilities.NormalizeConnectionString(malformedConnectionString, dbType, mockLogger.Object);

// Assert
Assert.AreEqual(malformedConnectionString, result);
mockLogger.Verify(
x => x.Log(
LogLevel.Warning,
It.IsAny<EventId>(),
It.Is<It.IsAnyType>((v, t) => true),
It.IsAny<Exception>(),
It.Is<Func<It.IsAnyType, Exception?, string>>((v, t) => true)),
Times.Once);
}

/// <summary>
/// Tests that null logger is handled gracefully.
/// </summary>
[TestMethod]
public void NormalizeConnectionString_MalformedString_NullLogger_ReturnsOriginal()
{
// Arrange
string malformedConnectionString = "InvalidConnectionString;NoEquals";
DatabaseType dbType = DatabaseType.MSSQL;

// Act
string result = HealthCheck.Utilities.NormalizeConnectionString(malformedConnectionString, dbType, null);

// Assert
Assert.AreEqual(malformedConnectionString, result);
}

/// <summary>
/// Tests that PostgreSQL connection strings with lowercase keywords are normalized correctly.
/// This is the specific bug that was reported - lowercase 'host' was not supported.
/// </summary>
[TestMethod]
public void NormalizeConnectionString_PostgreSQL_LowercaseKeywords_Success()
{
// Arrange
string connectionString = "host=localhost;port=5432;database=mydb;username=myuser;password=mypass";
DatabaseType dbType = DatabaseType.PostgreSQL;

// Act
string result = HealthCheck.Utilities.NormalizeConnectionString(connectionString, dbType);

// Assert
Assert.IsNotNull(result);
// NpgsqlConnectionStringBuilder should normalize lowercase keywords to proper format
Assert.IsTrue(result.Contains("Host=localhost") || result.Contains("host=localhost"));
Assert.IsTrue(result.Contains("Database=mydb") || result.Contains("database=mydb"));
}

/// <summary>
/// Tests that empty connection strings are handled gracefully.
/// </summary>
[TestMethod]
public void NormalizeConnectionString_EmptyString_ReturnsEmpty()
{
// Arrange
string connectionString = string.Empty;
DatabaseType dbType = DatabaseType.PostgreSQL;

// Act
string result = HealthCheck.Utilities.NormalizeConnectionString(connectionString, dbType);

// Assert
Assert.AreEqual(string.Empty, result);
}
}
}
6 changes: 3 additions & 3 deletions src/Service/HealthCheck/HealthCheckHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ private async Task UpdateDataSourceHealthCheckResultsAsync(ComprehensiveHealthCh
if (comprehensiveHealthCheckReport.Checks != null && runtimeConfig.DataSource.IsDatasourceHealthEnabled)
{
string query = Utilities.GetDatSourceQuery(runtimeConfig.DataSource.DatabaseType);
(int, string?) response = await ExecuteDatasourceQueryCheckAsync(query, runtimeConfig.DataSource.ConnectionString, Utilities.GetDbProviderFactory(runtimeConfig.DataSource.DatabaseType));
(int, string?) response = await ExecuteDatasourceQueryCheckAsync(query, runtimeConfig.DataSource.ConnectionString, Utilities.GetDbProviderFactory(runtimeConfig.DataSource.DatabaseType), runtimeConfig.DataSource.DatabaseType);
bool isResponseTimeWithinThreshold = response.Item1 >= 0 && response.Item1 < runtimeConfig.DataSource.DatasourceThresholdMs;

// Add DataSource Health Check Results
Expand All @@ -182,14 +182,14 @@ private async Task UpdateDataSourceHealthCheckResultsAsync(ComprehensiveHealthCh
}

// Executes the DB Query and keeps track of the response time and error message.
private async Task<(int, string?)> ExecuteDatasourceQueryCheckAsync(string query, string connectionString, DbProviderFactory dbProviderFactory)
private async Task<(int, string?)> ExecuteDatasourceQueryCheckAsync(string query, string connectionString, DbProviderFactory dbProviderFactory, DatabaseType databaseType)
{
string? errorMessage = null;
if (!string.IsNullOrEmpty(query) && !string.IsNullOrEmpty(connectionString))
{
Stopwatch stopwatch = new();
stopwatch.Start();
errorMessage = await _httpUtility.ExecuteDbQueryAsync(query, connectionString, dbProviderFactory);
errorMessage = await _httpUtility.ExecuteDbQueryAsync(query, connectionString, dbProviderFactory, databaseType);
stopwatch.Stop();
return string.IsNullOrEmpty(errorMessage) ? ((int)stopwatch.ElapsedMilliseconds, errorMessage) : (HealthCheckConstants.ERROR_RESPONSE_TIME_MS, errorMessage);
}
Expand Down
4 changes: 2 additions & 2 deletions src/Service/HealthCheck/HttpUtilities.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public HttpUtilities(
}

// Executes the DB query by establishing a connection to the DB.
public async Task<string?> ExecuteDbQueryAsync(string query, string connectionString, DbProviderFactory providerFactory)
public async Task<string?> ExecuteDbQueryAsync(string query, string connectionString, DbProviderFactory providerFactory, DatabaseType databaseType)
{
string? errorMessage = null;
// Execute the query on DB and return the response time.
Expand All @@ -65,7 +65,7 @@ public HttpUtilities(
{
try
{
connection.ConnectionString = connectionString;
connection.ConnectionString = Utilities.NormalizeConnectionString(connectionString, databaseType, _logger);
using (DbCommand command = connection.CreateCommand())
{
command.CommandText = query;
Expand Down
29 changes: 29 additions & 0 deletions src/Service/HealthCheck/Utilities.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
using System.Text.Json;
using Azure.DataApiBuilder.Config.ObjectModel;
using Microsoft.Data.SqlClient;
using Microsoft.Extensions.Logging;
using MySqlConnector;
using Npgsql;

namespace Azure.DataApiBuilder.Service.HealthCheck
Expand Down Expand Up @@ -69,5 +71,32 @@ public static string CreateHttpRestQuery(string entityName, int first)
// "EntityName?$first=4"
return $"/{entityName}?$first={first}";
}

public static string NormalizeConnectionString(string connectionString, DatabaseType dbType, ILogger? logger = null)
{
try
{
switch (dbType)
{
case DatabaseType.PostgreSQL:
return new NpgsqlConnectionStringBuilder(connectionString).ToString();
case DatabaseType.MySQL:
return new MySqlConnectionStringBuilder(connectionString).ToString();
case DatabaseType.MSSQL:
case DatabaseType.DWSQL:
return new SqlConnectionStringBuilder(connectionString).ToString();
default:
return connectionString;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot for sake of completeness, should you also normalize MySQL connection string?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added MySQL connection string normalization in commit d94b703. The method now uses MySqlConnectionStringBuilder for MySQL databases, consistent with the pattern used for PostgreSQL and MSSQL.

}
catch (Exception ex)
{
// Log the exception if a logger is provided
logger?.LogWarning(ex, "Failed to parse connection string for database type {DatabaseType}. Returning original connection string.", dbType);
// If the connection string cannot be parsed by the builder,
// return the original string to avoid failing the health check.
return connectionString;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot at least log the exception or throw them instead of swallowing silently

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added logging in commit d94b703. The method now accepts an optional ILogger parameter and logs warnings when connection string parsing fails. The logger is passed from HttpUtilities where it's already available, providing visibility into parsing failures without exposing sensitive data in the connection string.

}
}
}