Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dotnet/src/VectorData/PgVector/PostgresCollection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public class PostgresCollection<TKey, TRecord> : VectorStoreCollection<TKey, TRe
private readonly string _databaseName;

/// <summary>The database schema.</summary>
private readonly string _schema;
private readonly string? _schema;

/// <summary>The model for this collection.</summary>
private readonly CollectionModel _model;
Expand Down
4 changes: 2 additions & 2 deletions dotnet/src/VectorData/PgVector/PostgresCollectionOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ public PostgresCollectionOptions()

internal PostgresCollectionOptions(PostgresCollectionOptions? source) : base(source)
{
this.Schema = source?.Schema ?? PostgresVectorStoreOptions.Default.Schema;
this.Schema = source?.Schema;
}

/// <summary>
/// Gets or sets the database schema.
/// </summary>
public string Schema { get; set; } = PostgresVectorStoreOptions.Default.Schema;
public string? Schema { get; set; }
}
100 changes: 57 additions & 43 deletions dotnet/src/VectorData/PgVector/PostgresSqlBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,33 +20,36 @@ namespace Microsoft.SemanticKernel.Connectors.PgVector;
[System.Diagnostics.CodeAnalysis.SuppressMessage("Security", "CA2100:Review SQL queries for security vulnerabilities", Justification = "We need to build the full table name using schema and collection, it does not support parameterized passing.")]
internal static class PostgresSqlBuilder
{
internal static void BuildDoesTableExistCommand(NpgsqlCommand command, string schema, string tableName)
internal static void BuildDoesTableExistCommand(NpgsqlCommand command, string? schema, string tableName)
{
command.CommandText = """
SELECT table_name
FROM information_schema.tables
WHERE table_schema = $1
AND table_type = 'BASE TABLE'
AND table_name = $2
""";

Debug.Assert(command.Parameters.Count == 0);
command.Parameters.Add(new() { Value = schema });

command.CommandText = """
SELECT table_name
FROM information_schema.tables
WHERE table_schema = $1
AND table_type = 'BASE TABLE'
AND table_name = $2
""";

command.Parameters.Add(new() { Value = schema ?? "public" });
command.Parameters.Add(new() { Value = tableName });
}

internal static void BuildGetTablesCommand(NpgsqlCommand command, string schema)
internal static void BuildGetTablesCommand(NpgsqlCommand command, string? schema)
{
command.CommandText = """
SELECT table_name
FROM information_schema.tables
WHERE table_schema = $1 AND table_type = 'BASE TABLE'
""";
Debug.Assert(command.Parameters.Count == 0);
command.Parameters.Add(new() { Value = schema });

command.CommandText = """
SELECT table_name
FROM information_schema.tables
WHERE table_schema = $1 AND table_type = 'BASE TABLE'
""";

command.Parameters.Add(new() { Value = schema ?? "public" });
}

internal static string BuildCreateTableSql(string schema, string tableName, CollectionModel model, Version pgVersion, bool ifNotExists = true)
internal static string BuildCreateTableSql(string? schema, string tableName, CollectionModel model, Version pgVersion, bool ifNotExists = true)
{
if (string.IsNullOrWhiteSpace(tableName))
{
Expand All @@ -61,7 +64,7 @@ internal static string BuildCreateTableSql(string schema, string tableName, Coll
{
createTableCommand.Append("IF NOT EXISTS ");
}
createTableCommand.AppendIdentifier(schema).Append('.').AppendIdentifier(tableName).AppendLine(" (");
createTableCommand.AppendTableName(schema, tableName).AppendLine(" (");

// Add the key column
var keyStoreType = PostgresPropertyMapping.GetPostgresTypeName(model.KeyProperty).PgType;
Expand Down Expand Up @@ -119,7 +122,7 @@ internal static string BuildCreateTableSql(string schema, string tableName, Coll
}

/// <inheritdoc />
internal static string BuildCreateIndexSql(string schema, string tableName, string columnName, string indexKind, string distanceFunction, bool isVector, bool isFullText, string? fullTextLanguage, bool ifNotExists)
internal static string BuildCreateIndexSql(string? schema, string tableName, string columnName, string indexKind, string distanceFunction, bool isVector, bool isFullText, string? fullTextLanguage, bool ifNotExists)
{
var indexName = $"{tableName}_{columnName}_index";

Expand All @@ -134,14 +137,14 @@ internal static string BuildCreateIndexSql(string schema, string tableName, stri
{
// Create a GIN index for full-text search
var language = fullTextLanguage ?? PostgresConstants.DefaultFullTextSearchLanguage;
sql.AppendIdentifier(indexName).Append(" ON ").AppendIdentifier(schema).Append('.').AppendIdentifier(tableName)
sql.AppendIdentifier(indexName).Append(" ON ").AppendTableName(schema, tableName)
.Append(" USING GIN (to_tsvector(").AppendLiteral(language).Append(", ").AppendIdentifier(columnName).Append("))");
return sql.ToString();
}

if (!isVector)
{
sql.AppendIdentifier(indexName).Append(" ON ").AppendIdentifier(schema).Append('.').AppendIdentifier(tableName)
sql.AppendIdentifier(indexName).Append(" ON ").AppendTableName(schema, tableName)
.Append(" (").AppendIdentifier(columnName).Append(')');
return sql.ToString();
}
Expand All @@ -167,23 +170,23 @@ internal static string BuildCreateIndexSql(string schema, string tableName, stri
_ => throw new NotSupportedException($"Distance function {distanceFunction} is not supported.")
};

sql.AppendIdentifier(indexName).Append(" ON ").AppendIdentifier(schema).Append('.').AppendIdentifier(tableName)
sql.AppendIdentifier(indexName).Append(" ON ").AppendTableName(schema, tableName)
.Append(" USING ").Append(indexTypeName).Append(" (").AppendIdentifier(columnName).Append(' ').Append(indexOps).Append(')');
return sql.ToString();
}

/// <inheritdoc />
internal static void BuildDropTableCommand(NpgsqlCommand command, string schema, string tableName)
internal static void BuildDropTableCommand(NpgsqlCommand command, string? schema, string tableName)
{
StringBuilder sql = new();
sql.Append("DROP TABLE IF EXISTS ").AppendIdentifier(schema).Append('.').AppendIdentifier(tableName);
sql.Append("DROP TABLE IF EXISTS ").AppendTableName(schema, tableName);
command.CommandText = sql.ToString();
}

/// <inheritdoc />
internal static bool BuildUpsertCommand<TKey>(
NpgsqlBatch batch,
string schema,
string? schema,
string tableName,
CollectionModel model,
IEnumerable<object> records,
Expand Down Expand Up @@ -260,9 +263,7 @@ string GenerateSingleUpsertSql(bool autoGeneratedKey)

sqlBuilder
.Append("INSERT INTO ")
.AppendIdentifier(schema)
.Append('.')
.AppendIdentifier(tableName)
.AppendTableName(schema, tableName)
.Append(" (");

var i = 0;
Expand Down Expand Up @@ -345,7 +346,7 @@ string GenerateSingleUpsertSql(bool autoGeneratedKey)
}

/// <inheritdoc />
internal static void BuildGetCommand<TKey>(NpgsqlCommand command, string schema, string tableName, CollectionModel model, TKey key, bool includeVectors = false)
internal static void BuildGetCommand<TKey>(NpgsqlCommand command, string? schema, string tableName, CollectionModel model, TKey key, bool includeVectors = false)
where TKey : notnull
{
StringBuilder sql = new();
Expand All @@ -360,7 +361,7 @@ internal static void BuildGetCommand<TKey>(NpgsqlCommand command, string schema,
sql.AppendIdentifier(model.Properties[i].StorageName);
}

sql.AppendLine().Append("FROM ").AppendIdentifier(schema).Append('.').AppendIdentifier(tableName).AppendLine()
sql.AppendLine().Append("FROM ").AppendTableName(schema, tableName).AppendLine()
.Append("WHERE ").AppendIdentifier(model.KeyProperty.StorageName).Append(" = $1;");

command.CommandText = sql.ToString();
Expand All @@ -369,7 +370,7 @@ internal static void BuildGetCommand<TKey>(NpgsqlCommand command, string schema,
}

/// <inheritdoc />
internal static void BuildGetBatchCommand<TKey>(NpgsqlCommand command, string schema, string tableName, CollectionModel model, List<TKey> keys, bool includeVectors = false)
internal static void BuildGetBatchCommand<TKey>(NpgsqlCommand command, string? schema, string tableName, CollectionModel model, List<TKey> keys, bool includeVectors = false)
where TKey : notnull
{
NpgsqlDbType? keyType = PostgresPropertyMapping.GetNpgsqlDbType(model.KeyProperty) ?? throw new UnreachableException($"Unsupported key type {model.KeyProperty.Type.Name}");
Expand All @@ -393,7 +394,7 @@ internal static void BuildGetBatchCommand<TKey>(NpgsqlCommand command, string sc
sql.AppendIdentifier(property.StorageName);
}

sql.AppendLine().Append("FROM ").AppendIdentifier(schema).Append('.').AppendIdentifier(tableName).AppendLine()
sql.AppendLine().Append("FROM ").AppendTableName(schema, tableName).AppendLine()
.Append("WHERE ").AppendIdentifier(model.KeyProperty.StorageName).Append(" = ANY($1);");

command.CommandText = sql.ToString();
Expand All @@ -406,10 +407,10 @@ internal static void BuildGetBatchCommand<TKey>(NpgsqlCommand command, string sc
}

/// <inheritdoc />
internal static void BuildDeleteCommand<TKey>(NpgsqlCommand command, string schema, string tableName, string keyColumn, TKey key)
internal static void BuildDeleteCommand<TKey>(NpgsqlCommand command, string? schema, string tableName, string keyColumn, TKey key)
{
StringBuilder sql = new();
sql.Append("DELETE FROM ").AppendIdentifier(schema).Append('.').AppendIdentifier(tableName).AppendLine()
sql.Append("DELETE FROM ").AppendTableName(schema, tableName).AppendLine()
.Append("WHERE ").AppendIdentifier(keyColumn).Append(" = $1;");

command.CommandText = sql.ToString();
Expand All @@ -418,7 +419,7 @@ internal static void BuildDeleteCommand<TKey>(NpgsqlCommand command, string sche
}

/// <inheritdoc />
internal static void BuildDeleteBatchCommand<TKey>(NpgsqlCommand command, string schema, string tableName, KeyPropertyModel keyProperty, List<TKey> keys)
internal static void BuildDeleteBatchCommand<TKey>(NpgsqlCommand command, string? schema, string tableName, KeyPropertyModel keyProperty, List<TKey> keys)
{
NpgsqlDbType? keyType = PostgresPropertyMapping.GetNpgsqlDbType(keyProperty) ?? throw new ArgumentException($"Unsupported key type {typeof(TKey).Name}");

Expand All @@ -431,7 +432,7 @@ internal static void BuildDeleteBatchCommand<TKey>(NpgsqlCommand command, string
}

StringBuilder sql = new();
sql.Append("DELETE FROM ").AppendIdentifier(schema).Append('.').AppendIdentifier(tableName).AppendLine()
sql.Append("DELETE FROM ").AppendTableName(schema, tableName).AppendLine()
.Append("WHERE ").AppendIdentifier(keyProperty.StorageName).Append(" = ANY($1);");

command.CommandText = sql.ToString();
Expand Down Expand Up @@ -504,7 +505,7 @@ private static (string Condition, List<object> Parameters) GenerateFilterConditi
#pragma warning disable CS0618 // VectorSearchFilter is obsolete
/// <inheritdoc />
internal static void BuildGetNearestMatchCommand<TRecord>(
NpgsqlCommand command, string schema, string tableName, CollectionModel model, VectorPropertyModel vectorProperty, object vectorValue,
NpgsqlCommand command, string? schema, string tableName, CollectionModel model, VectorPropertyModel vectorProperty, object vectorValue,
VectorSearchFilter? legacyFilter, Expression<Func<TRecord, bool>>? newFilter, int? skip, bool includeVectors, int limit,
double? scoreThreshold = null)
{
Expand All @@ -528,7 +529,7 @@ internal static void BuildGetNearestMatchCommand<TRecord>(
StringBuilder sql = new();
sql.Append("SELECT ").Append(columns).Append(", ").AppendIdentifier(vectorProperty.StorageName)
.Append(' ').Append(distanceOp).Append(" $1 AS ").AppendIdentifier(PostgresConstants.DistanceColumnName).AppendLine()
.Append("FROM ").AppendIdentifier(schema).Append('.').AppendIdentifier(tableName)
.Append("FROM ").AppendTableName(schema, tableName)
.Append(' ').AppendLine(where)
.Append("ORDER BY ").AppendLine(PostgresConstants.DistanceColumnName)
.Append("LIMIT ").Append(limit);
Expand Down Expand Up @@ -608,7 +609,7 @@ or DistanceFunction.HammingDistance
}

internal static void BuildSelectWhereCommand<TRecord>(
NpgsqlCommand command, string schema, string tableName, CollectionModel model,
NpgsqlCommand command, string? schema, string tableName, CollectionModel model,
Expression<Func<TRecord, bool>> filter, int top, FilteredRecordRetrievalOptions<TRecord> options)
{
StringBuilder query = new(200);
Expand All @@ -627,7 +628,7 @@ internal static void BuildSelectWhereCommand<TRecord>(
}
}
query.AppendLine();
query.Append("FROM ").AppendIdentifier(schema).Append('.').AppendIdentifier(tableName).AppendLine();
query.Append("FROM ").AppendTableName(schema, tableName).AppendLine();

PostgresFilterTranslator translator = new(model, filter, startParamIndex: 1, query);
translator.Translate(appendWhere: true);
Expand Down Expand Up @@ -678,6 +679,19 @@ internal static (string Condition, List<object> Parameters) GenerateNewFilterCon
return (translator.Clause.ToString(), translator.ParameterValues);
}

/// <summary>
/// Appends a schema-qualified table name. If schema is null or empty, omits the schema prefix.
/// </summary>
private static StringBuilder AppendTableName(this StringBuilder sb, string? schema, string tableName)
{
if (!string.IsNullOrEmpty(schema))
{
sb.AppendIdentifier(schema!).Append('.');
}

return sb.AppendIdentifier(tableName);
}

#pragma warning disable CS0618 // VectorSearchFilter is obsolete
internal static (string Clause, List<object> Parameters) GenerateLegacyFilterWhereClause(CollectionModel model, VectorSearchFilter legacyFilter, int startParamIndex)
{
Expand Down Expand Up @@ -739,7 +753,7 @@ internal static (string Condition, List<object> Parameters) GenerateLegacyFilter
/// Builds a hybrid search command that combines vector similarity search with full-text keyword search using RRF (Reciprocal Rank Fusion).
/// </summary>
internal static void BuildHybridSearchCommand<TRecord>(
NpgsqlCommand command, string schema, string tableName, CollectionModel model,
NpgsqlCommand command, string? schema, string tableName, CollectionModel model,
VectorPropertyModel vectorProperty, DataPropertyModel textProperty,
object vectorValue, ICollection<string> keywords,
VectorSearchFilter? legacyFilter, Expression<Func<TRecord, bool>>? newFilter,
Expand Down Expand Up @@ -773,7 +787,7 @@ internal static void BuildHybridSearchCommand<TRecord>(

// Build the full table name
var fullTableName = new StringBuilder()
.AppendIdentifier(schema).Append('.').AppendIdentifier(tableName)
.AppendTableName(schema, tableName)
.ToString();

// Use a larger internal limit for the CTEs to get better ranking, then limit final results
Expand Down
4 changes: 2 additions & 2 deletions dotnet/src/VectorData/PgVector/PostgresVectorStore.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public sealed class PostgresVectorStore : VectorStore
private readonly string _databaseName;

/// <summary>The database schema.</summary>
private readonly string _schema;
private readonly string? _schema;

private readonly IEmbeddingGenerator? _embeddingGenerator;

Expand All @@ -44,7 +44,7 @@ public PostgresVectorStore(NpgsqlDataSource dataSource, bool ownsDataSource, Pos
{
Verify.NotNull(dataSource);

this._schema = options?.Schema ?? PostgresVectorStoreOptions.Default.Schema;
this._schema = options?.Schema;
this._embeddingGenerator = options?.EmbeddingGenerator;
this._dataSource = dataSource;
this._dataSourceArc = ownsDataSource ? new NpgsqlDataSourceArc(dataSource) : null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@ public PostgresVectorStoreOptions()

internal PostgresVectorStoreOptions(PostgresVectorStoreOptions? source)
{
this.Schema = source?.Schema ?? Default.Schema;
this.Schema = source?.Schema;
this.EmbeddingGenerator = source?.EmbeddingGenerator;
}

/// <summary>
/// Gets or sets the database schema.
/// </summary>
public string Schema { get; set; } = "public";
public string? Schema { get; set; }

/// <summary>
/// Gets or sets the default embedding generator to use when generating vectors embeddings with this vector store.
Expand Down
Loading
Loading