diff --git a/dotnet/src/VectorData/PgVector/PostgresCollection.cs b/dotnet/src/VectorData/PgVector/PostgresCollection.cs index 14d9cef81f69..81c6975be7d0 100644 --- a/dotnet/src/VectorData/PgVector/PostgresCollection.cs +++ b/dotnet/src/VectorData/PgVector/PostgresCollection.cs @@ -41,7 +41,7 @@ public class PostgresCollection : VectorStoreCollectionThe database schema. - private readonly string _schema; + private readonly string? _schema; /// The model for this collection. private readonly CollectionModel _model; diff --git a/dotnet/src/VectorData/PgVector/PostgresCollectionOptions.cs b/dotnet/src/VectorData/PgVector/PostgresCollectionOptions.cs index 864e4f3877cc..ecfc6d83d13e 100644 --- a/dotnet/src/VectorData/PgVector/PostgresCollectionOptions.cs +++ b/dotnet/src/VectorData/PgVector/PostgresCollectionOptions.cs @@ -20,11 +20,11 @@ public PostgresCollectionOptions() internal PostgresCollectionOptions(PostgresCollectionOptions? source) : base(source) { - this.Schema = source?.Schema ?? PostgresVectorStoreOptions.Default.Schema; + this.Schema = source?.Schema; } /// /// Gets or sets the database schema. /// - public string Schema { get; set; } = PostgresVectorStoreOptions.Default.Schema; + public string? Schema { get; set; } } diff --git a/dotnet/src/VectorData/PgVector/PostgresSqlBuilder.cs b/dotnet/src/VectorData/PgVector/PostgresSqlBuilder.cs index 6ac787f26ad3..64a320f70a6a 100644 --- a/dotnet/src/VectorData/PgVector/PostgresSqlBuilder.cs +++ b/dotnet/src/VectorData/PgVector/PostgresSqlBuilder.cs @@ -20,33 +20,36 @@ namespace Microsoft.SemanticKernel.Connectors.PgVector; [System.Diagnostics.CodeAnalysis.SuppressMessage("Security", "CA2100:Review SQL queries for security vulnerabilities", Justification = "We need to build the full table name using schema and collection, it does not support parameterized passing.")] internal static class PostgresSqlBuilder { - internal static void BuildDoesTableExistCommand(NpgsqlCommand command, string schema, string tableName) + internal static void BuildDoesTableExistCommand(NpgsqlCommand command, string? schema, string tableName) { - command.CommandText = """ -SELECT table_name -FROM information_schema.tables -WHERE table_schema = $1 - AND table_type = 'BASE TABLE' - AND table_name = $2 -"""; - Debug.Assert(command.Parameters.Count == 0); - command.Parameters.Add(new() { Value = schema }); + + command.CommandText = """ + SELECT table_name + FROM information_schema.tables + WHERE table_schema = $1 + AND table_type = 'BASE TABLE' + AND table_name = $2 + """; + + command.Parameters.Add(new() { Value = schema ?? "public" }); command.Parameters.Add(new() { Value = tableName }); } - internal static void BuildGetTablesCommand(NpgsqlCommand command, string schema) + internal static void BuildGetTablesCommand(NpgsqlCommand command, string? schema) { - command.CommandText = """ -SELECT table_name -FROM information_schema.tables -WHERE table_schema = $1 AND table_type = 'BASE TABLE' -"""; Debug.Assert(command.Parameters.Count == 0); - command.Parameters.Add(new() { Value = schema }); + + command.CommandText = """ + SELECT table_name + FROM information_schema.tables + WHERE table_schema = $1 AND table_type = 'BASE TABLE' + """; + + command.Parameters.Add(new() { Value = schema ?? "public" }); } - internal static string BuildCreateTableSql(string schema, string tableName, CollectionModel model, Version pgVersion, bool ifNotExists = true) + internal static string BuildCreateTableSql(string? schema, string tableName, CollectionModel model, Version pgVersion, bool ifNotExists = true) { if (string.IsNullOrWhiteSpace(tableName)) { @@ -61,7 +64,7 @@ internal static string BuildCreateTableSql(string schema, string tableName, Coll { createTableCommand.Append("IF NOT EXISTS "); } - createTableCommand.AppendIdentifier(schema).Append('.').AppendIdentifier(tableName).AppendLine(" ("); + createTableCommand.AppendTableName(schema, tableName).AppendLine(" ("); // Add the key column var keyStoreType = PostgresPropertyMapping.GetPostgresTypeName(model.KeyProperty).PgType; @@ -119,7 +122,7 @@ internal static string BuildCreateTableSql(string schema, string tableName, Coll } /// - internal static string BuildCreateIndexSql(string schema, string tableName, string columnName, string indexKind, string distanceFunction, bool isVector, bool isFullText, string? fullTextLanguage, bool ifNotExists) + internal static string BuildCreateIndexSql(string? schema, string tableName, string columnName, string indexKind, string distanceFunction, bool isVector, bool isFullText, string? fullTextLanguage, bool ifNotExists) { var indexName = $"{tableName}_{columnName}_index"; @@ -134,14 +137,14 @@ internal static string BuildCreateIndexSql(string schema, string tableName, stri { // Create a GIN index for full-text search var language = fullTextLanguage ?? PostgresConstants.DefaultFullTextSearchLanguage; - sql.AppendIdentifier(indexName).Append(" ON ").AppendIdentifier(schema).Append('.').AppendIdentifier(tableName) + sql.AppendIdentifier(indexName).Append(" ON ").AppendTableName(schema, tableName) .Append(" USING GIN (to_tsvector(").AppendLiteral(language).Append(", ").AppendIdentifier(columnName).Append("))"); return sql.ToString(); } if (!isVector) { - sql.AppendIdentifier(indexName).Append(" ON ").AppendIdentifier(schema).Append('.').AppendIdentifier(tableName) + sql.AppendIdentifier(indexName).Append(" ON ").AppendTableName(schema, tableName) .Append(" (").AppendIdentifier(columnName).Append(')'); return sql.ToString(); } @@ -167,23 +170,23 @@ internal static string BuildCreateIndexSql(string schema, string tableName, stri _ => throw new NotSupportedException($"Distance function {distanceFunction} is not supported.") }; - sql.AppendIdentifier(indexName).Append(" ON ").AppendIdentifier(schema).Append('.').AppendIdentifier(tableName) + sql.AppendIdentifier(indexName).Append(" ON ").AppendTableName(schema, tableName) .Append(" USING ").Append(indexTypeName).Append(" (").AppendIdentifier(columnName).Append(' ').Append(indexOps).Append(')'); return sql.ToString(); } /// - internal static void BuildDropTableCommand(NpgsqlCommand command, string schema, string tableName) + internal static void BuildDropTableCommand(NpgsqlCommand command, string? schema, string tableName) { StringBuilder sql = new(); - sql.Append("DROP TABLE IF EXISTS ").AppendIdentifier(schema).Append('.').AppendIdentifier(tableName); + sql.Append("DROP TABLE IF EXISTS ").AppendTableName(schema, tableName); command.CommandText = sql.ToString(); } /// internal static bool BuildUpsertCommand( NpgsqlBatch batch, - string schema, + string? schema, string tableName, CollectionModel model, IEnumerable records, @@ -260,9 +263,7 @@ string GenerateSingleUpsertSql(bool autoGeneratedKey) sqlBuilder .Append("INSERT INTO ") - .AppendIdentifier(schema) - .Append('.') - .AppendIdentifier(tableName) + .AppendTableName(schema, tableName) .Append(" ("); var i = 0; @@ -345,7 +346,7 @@ string GenerateSingleUpsertSql(bool autoGeneratedKey) } /// - internal static void BuildGetCommand(NpgsqlCommand command, string schema, string tableName, CollectionModel model, TKey key, bool includeVectors = false) + internal static void BuildGetCommand(NpgsqlCommand command, string? schema, string tableName, CollectionModel model, TKey key, bool includeVectors = false) where TKey : notnull { StringBuilder sql = new(); @@ -360,7 +361,7 @@ internal static void BuildGetCommand(NpgsqlCommand command, string schema, sql.AppendIdentifier(model.Properties[i].StorageName); } - sql.AppendLine().Append("FROM ").AppendIdentifier(schema).Append('.').AppendIdentifier(tableName).AppendLine() + sql.AppendLine().Append("FROM ").AppendTableName(schema, tableName).AppendLine() .Append("WHERE ").AppendIdentifier(model.KeyProperty.StorageName).Append(" = $1;"); command.CommandText = sql.ToString(); @@ -369,7 +370,7 @@ internal static void BuildGetCommand(NpgsqlCommand command, string schema, } /// - internal static void BuildGetBatchCommand(NpgsqlCommand command, string schema, string tableName, CollectionModel model, List keys, bool includeVectors = false) + internal static void BuildGetBatchCommand(NpgsqlCommand command, string? schema, string tableName, CollectionModel model, List keys, bool includeVectors = false) where TKey : notnull { NpgsqlDbType? keyType = PostgresPropertyMapping.GetNpgsqlDbType(model.KeyProperty) ?? throw new UnreachableException($"Unsupported key type {model.KeyProperty.Type.Name}"); @@ -393,7 +394,7 @@ internal static void BuildGetBatchCommand(NpgsqlCommand command, string sc sql.AppendIdentifier(property.StorageName); } - sql.AppendLine().Append("FROM ").AppendIdentifier(schema).Append('.').AppendIdentifier(tableName).AppendLine() + sql.AppendLine().Append("FROM ").AppendTableName(schema, tableName).AppendLine() .Append("WHERE ").AppendIdentifier(model.KeyProperty.StorageName).Append(" = ANY($1);"); command.CommandText = sql.ToString(); @@ -406,10 +407,10 @@ internal static void BuildGetBatchCommand(NpgsqlCommand command, string sc } /// - internal static void BuildDeleteCommand(NpgsqlCommand command, string schema, string tableName, string keyColumn, TKey key) + internal static void BuildDeleteCommand(NpgsqlCommand command, string? schema, string tableName, string keyColumn, TKey key) { StringBuilder sql = new(); - sql.Append("DELETE FROM ").AppendIdentifier(schema).Append('.').AppendIdentifier(tableName).AppendLine() + sql.Append("DELETE FROM ").AppendTableName(schema, tableName).AppendLine() .Append("WHERE ").AppendIdentifier(keyColumn).Append(" = $1;"); command.CommandText = sql.ToString(); @@ -418,7 +419,7 @@ internal static void BuildDeleteCommand(NpgsqlCommand command, string sche } /// - internal static void BuildDeleteBatchCommand(NpgsqlCommand command, string schema, string tableName, KeyPropertyModel keyProperty, List keys) + internal static void BuildDeleteBatchCommand(NpgsqlCommand command, string? schema, string tableName, KeyPropertyModel keyProperty, List keys) { NpgsqlDbType? keyType = PostgresPropertyMapping.GetNpgsqlDbType(keyProperty) ?? throw new ArgumentException($"Unsupported key type {typeof(TKey).Name}"); @@ -431,7 +432,7 @@ internal static void BuildDeleteBatchCommand(NpgsqlCommand command, string } StringBuilder sql = new(); - sql.Append("DELETE FROM ").AppendIdentifier(schema).Append('.').AppendIdentifier(tableName).AppendLine() + sql.Append("DELETE FROM ").AppendTableName(schema, tableName).AppendLine() .Append("WHERE ").AppendIdentifier(keyProperty.StorageName).Append(" = ANY($1);"); command.CommandText = sql.ToString(); @@ -504,7 +505,7 @@ private static (string Condition, List Parameters) GenerateFilterConditi #pragma warning disable CS0618 // VectorSearchFilter is obsolete /// internal static void BuildGetNearestMatchCommand( - NpgsqlCommand command, string schema, string tableName, CollectionModel model, VectorPropertyModel vectorProperty, object vectorValue, + NpgsqlCommand command, string? schema, string tableName, CollectionModel model, VectorPropertyModel vectorProperty, object vectorValue, VectorSearchFilter? legacyFilter, Expression>? newFilter, int? skip, bool includeVectors, int limit, double? scoreThreshold = null) { @@ -528,7 +529,7 @@ internal static void BuildGetNearestMatchCommand( StringBuilder sql = new(); sql.Append("SELECT ").Append(columns).Append(", ").AppendIdentifier(vectorProperty.StorageName) .Append(' ').Append(distanceOp).Append(" $1 AS ").AppendIdentifier(PostgresConstants.DistanceColumnName).AppendLine() - .Append("FROM ").AppendIdentifier(schema).Append('.').AppendIdentifier(tableName) + .Append("FROM ").AppendTableName(schema, tableName) .Append(' ').AppendLine(where) .Append("ORDER BY ").AppendLine(PostgresConstants.DistanceColumnName) .Append("LIMIT ").Append(limit); @@ -608,7 +609,7 @@ or DistanceFunction.HammingDistance } internal static void BuildSelectWhereCommand( - NpgsqlCommand command, string schema, string tableName, CollectionModel model, + NpgsqlCommand command, string? schema, string tableName, CollectionModel model, Expression> filter, int top, FilteredRecordRetrievalOptions options) { StringBuilder query = new(200); @@ -627,7 +628,7 @@ internal static void BuildSelectWhereCommand( } } query.AppendLine(); - query.Append("FROM ").AppendIdentifier(schema).Append('.').AppendIdentifier(tableName).AppendLine(); + query.Append("FROM ").AppendTableName(schema, tableName).AppendLine(); PostgresFilterTranslator translator = new(model, filter, startParamIndex: 1, query); translator.Translate(appendWhere: true); @@ -678,6 +679,19 @@ internal static (string Condition, List Parameters) GenerateNewFilterCon return (translator.Clause.ToString(), translator.ParameterValues); } + /// + /// Appends a schema-qualified table name. If schema is null or empty, omits the schema prefix. + /// + private static StringBuilder AppendTableName(this StringBuilder sb, string? schema, string tableName) + { + if (!string.IsNullOrEmpty(schema)) + { + sb.AppendIdentifier(schema!).Append('.'); + } + + return sb.AppendIdentifier(tableName); + } + #pragma warning disable CS0618 // VectorSearchFilter is obsolete internal static (string Clause, List Parameters) GenerateLegacyFilterWhereClause(CollectionModel model, VectorSearchFilter legacyFilter, int startParamIndex) { @@ -739,7 +753,7 @@ internal static (string Condition, List Parameters) GenerateLegacyFilter /// Builds a hybrid search command that combines vector similarity search with full-text keyword search using RRF (Reciprocal Rank Fusion). /// internal static void BuildHybridSearchCommand( - NpgsqlCommand command, string schema, string tableName, CollectionModel model, + NpgsqlCommand command, string? schema, string tableName, CollectionModel model, VectorPropertyModel vectorProperty, DataPropertyModel textProperty, object vectorValue, ICollection keywords, VectorSearchFilter? legacyFilter, Expression>? newFilter, @@ -773,7 +787,7 @@ internal static void BuildHybridSearchCommand( // Build the full table name var fullTableName = new StringBuilder() - .AppendIdentifier(schema).Append('.').AppendIdentifier(tableName) + .AppendTableName(schema, tableName) .ToString(); // Use a larger internal limit for the CTEs to get better ranking, then limit final results diff --git a/dotnet/src/VectorData/PgVector/PostgresVectorStore.cs b/dotnet/src/VectorData/PgVector/PostgresVectorStore.cs index 33ac82f946bc..fc2d601f2aa2 100644 --- a/dotnet/src/VectorData/PgVector/PostgresVectorStore.cs +++ b/dotnet/src/VectorData/PgVector/PostgresVectorStore.cs @@ -30,7 +30,7 @@ public sealed class PostgresVectorStore : VectorStore private readonly string _databaseName; /// The database schema. - private readonly string _schema; + private readonly string? _schema; private readonly IEmbeddingGenerator? _embeddingGenerator; @@ -44,7 +44,7 @@ public PostgresVectorStore(NpgsqlDataSource dataSource, bool ownsDataSource, Pos { Verify.NotNull(dataSource); - this._schema = options?.Schema ?? PostgresVectorStoreOptions.Default.Schema; + this._schema = options?.Schema; this._embeddingGenerator = options?.EmbeddingGenerator; this._dataSource = dataSource; this._dataSourceArc = ownsDataSource ? new NpgsqlDataSourceArc(dataSource) : null; diff --git a/dotnet/src/VectorData/PgVector/PostgresVectorStoreOptions.cs b/dotnet/src/VectorData/PgVector/PostgresVectorStoreOptions.cs index 445d6fa75a1a..8a09e8e8ac5e 100644 --- a/dotnet/src/VectorData/PgVector/PostgresVectorStoreOptions.cs +++ b/dotnet/src/VectorData/PgVector/PostgresVectorStoreOptions.cs @@ -20,14 +20,14 @@ public PostgresVectorStoreOptions() internal PostgresVectorStoreOptions(PostgresVectorStoreOptions? source) { - this.Schema = source?.Schema ?? Default.Schema; + this.Schema = source?.Schema; this.EmbeddingGenerator = source?.EmbeddingGenerator; } /// /// Gets or sets the database schema. /// - public string Schema { get; set; } = "public"; + public string? Schema { get; set; } /// /// Gets or sets the default embedding generator to use when generating vectors embeddings with this vector store. diff --git a/dotnet/test/VectorData/PgVector.UnitTests/PostgresSqlBuilderTests.cs b/dotnet/test/VectorData/PgVector.UnitTests/PostgresSqlBuilderTests.cs index c5667361f928..aa8de7d2eb7b 100644 --- a/dotnet/test/VectorData/PgVector.UnitTests/PostgresSqlBuilderTests.cs +++ b/dotnet/test/VectorData/PgVector.UnitTests/PostgresSqlBuilderTests.cs @@ -52,10 +52,11 @@ public void TestBuildCreateTableCommand(bool ifNotExists) var model = new PostgresModelBuilder().BuildDynamic(recordDefinition, defaultEmbeddingGenerator: null); - var sql = PostgresSqlBuilder.BuildCreateTableSql("public", "testcollection", model, pgVersion: new Version(18, 0), ifNotExists: ifNotExists); + var sql = PostgresSqlBuilder.BuildCreateTableSql(schema: null, "testcollection", model, pgVersion: new Version(18, 0), ifNotExists: ifNotExists); // Check for expected properties; integration tests will validate the actual SQL. - Assert.Contains("\"public\".\"testcollection\" (", sql); + Assert.Contains("\"testcollection\" (", sql); + Assert.DoesNotContain("\"public\"", sql); Assert.Contains("\"name\" TEXT", sql); Assert.Contains("\"code\" INTEGER NOT NULL", sql); Assert.Contains("\"rating\" REAL", sql); @@ -92,7 +93,7 @@ public void TestBuildCreateTableCommand_WithTimestampStoreType() var model = new PostgresModelBuilder().BuildDynamic(recordDefinition, defaultEmbeddingGenerator: null); - var sql = PostgresSqlBuilder.BuildCreateTableSql("public", "testcollection", model, pgVersion: new Version(18, 0)); + var sql = PostgresSqlBuilder.BuildCreateTableSql(schema: null, "testcollection", model, pgVersion: new Version(18, 0)); Assert.Contains("\"created_utc\" TIMESTAMPTZ NOT NULL", sql); Assert.Contains("\"created_local\" TIMESTAMP NOT NULL", sql); @@ -118,12 +119,12 @@ public void TestBuildCreateIndexCommand(string indexKind, string distanceFunctio if (indexKind != IndexKind.Hnsw) { - Assert.Throws(() => PostgresSqlBuilder.BuildCreateIndexSql("public", "testcollection", vectorColumn, indexKind, distanceFunction, isVector: true, isFullText: false, fullTextLanguage: null, ifNotExists)); - Assert.Throws(() => PostgresSqlBuilder.BuildCreateIndexSql("public", "testcollection", vectorColumn, indexKind, distanceFunction, isVector: true, isFullText: false, fullTextLanguage: null, ifNotExists)); + Assert.Throws(() => PostgresSqlBuilder.BuildCreateIndexSql(schema: null, "testcollection", vectorColumn, indexKind, distanceFunction, isVector: true, isFullText: false, fullTextLanguage: null, ifNotExists)); + Assert.Throws(() => PostgresSqlBuilder.BuildCreateIndexSql(schema: null, "testcollection", vectorColumn, indexKind, distanceFunction, isVector: true, isFullText: false, fullTextLanguage: null, ifNotExists)); return; } - var sql = PostgresSqlBuilder.BuildCreateIndexSql("public", "1testcollection", vectorColumn, indexKind, distanceFunction, isVector: true, isFullText: false, fullTextLanguage: null, ifNotExists); + var sql = PostgresSqlBuilder.BuildCreateIndexSql(schema: null, "1testcollection", vectorColumn, indexKind, distanceFunction, isVector: true, isFullText: false, fullTextLanguage: null, ifNotExists); // Check for expected properties; integration tests will validate the actual SQL. Assert.Contains("CREATE INDEX ", sql); @@ -139,7 +140,7 @@ public void TestBuildCreateIndexCommand(string indexKind, string distanceFunctio // Make sure the name is escaped, so names starting with a digit are OK. Assert.Contains($"\"1testcollection_{vectorColumn}_index\"", sql); - Assert.Contains("ON \"public\".\"1testcollection\" USING hnsw (\"embedding1\" ", sql); + Assert.Contains("ON \"1testcollection\" USING hnsw (\"embedding1\" ", sql); if (distanceFunction == null) { // Check for distance function defaults to cosine distance @@ -203,10 +204,11 @@ public void TestBuildCreateFullTextIndexCommand_EscapesSingleQuotes() public void TestBuildDropTableCommand() { using var command = new NpgsqlCommand(); - PostgresSqlBuilder.BuildDropTableCommand(command, "public", "testcollection"); + PostgresSqlBuilder.BuildDropTableCommand(command, schema: null, "testcollection"); // Check for expected properties; integration tests will validate the actual SQL. - Assert.Contains("DROP TABLE IF EXISTS \"public\".\"testcollection\"", command.CommandText); + Assert.Contains("DROP TABLE IF EXISTS \"testcollection\"", command.CommandText); + Assert.DoesNotContain("\"public\"", command.CommandText); // Output this._output.WriteLine(command.CommandText); @@ -253,12 +255,12 @@ public void TestBuildUpsertCommand() }; using var batch = new NpgsqlBatch(); - _ = PostgresSqlBuilder.BuildUpsertCommand(batch, "public", "testcollection", model, [record], generatedEmbeddings: null); + _ = PostgresSqlBuilder.BuildUpsertCommand(batch, schema: null, "testcollection", model, [record], generatedEmbeddings: null); // Check for expected properties; integration tests will validate the actual SQL. Assert.Single(batch.BatchCommands); var command = batch.BatchCommands[0]; - Assert.Contains("INSERT INTO \"public\".\"testcollection\" (", command.CommandText); + Assert.Contains("INSERT INTO \"testcollection\" (", command.CommandText); Assert.Contains("ON CONFLICT (\"id\")", command.CommandText); Assert.Contains("DO UPDATE SET", command.CommandText); Assert.NotNull(command.Parameters); @@ -316,13 +318,13 @@ public void TestBuildGetCommand() // Act using var command = new NpgsqlCommand(); - PostgresSqlBuilder.BuildGetCommand(command, "public", "testcollection", model, key, includeVectors: true); + PostgresSqlBuilder.BuildGetCommand(command, schema: null, "testcollection", model, key, includeVectors: true); // Assert Assert.Contains("SELECT", command.CommandText); Assert.Contains("\"free_parking\"", command.CommandText); Assert.Contains("\"embedding1\"", command.CommandText); - Assert.Contains("FROM \"public\".\"testcollection\"", command.CommandText); + Assert.Contains("FROM \"testcollection\"", command.CommandText); Assert.Contains("WHERE \"id\" = $1", command.CommandText); // Output @@ -360,13 +362,13 @@ public void TestBuildGetBatchCommand() // Act using var command = new NpgsqlCommand(); - PostgresSqlBuilder.BuildGetBatchCommand(command, "public", "testcollection", model, keys, includeVectors: true); + PostgresSqlBuilder.BuildGetBatchCommand(command, schema: null, "testcollection", model, keys, includeVectors: true); // Assert Assert.Contains("SELECT", command.CommandText); Assert.Contains("\"code\"", command.CommandText); Assert.Contains("\"free_parking\"", command.CommandText); - Assert.Contains("FROM \"public\".\"testcollection\"", command.CommandText); + Assert.Contains("FROM \"testcollection\"", command.CommandText); Assert.Contains("WHERE \"id\" = ANY($1)", command.CommandText); Assert.NotNull(command.Parameters); Assert.Single(command.Parameters); @@ -384,11 +386,11 @@ public void TestBuildDeleteCommand() // Act using var command = new NpgsqlCommand(); - PostgresSqlBuilder.BuildDeleteCommand(command, "public", "testcollection", "id", key); + PostgresSqlBuilder.BuildDeleteCommand(command, schema: null, "testcollection", "id", key); // Assert Assert.Contains("DELETE", command.CommandText); - Assert.Contains("FROM \"public\".\"testcollection\"", command.CommandText); + Assert.Contains("FROM \"testcollection\"", command.CommandText); Assert.Contains("WHERE \"id\" = $1", command.CommandText); // Output @@ -404,11 +406,11 @@ public void TestBuildDeleteBatchCommand() // Act using var command = new NpgsqlCommand(); var keyProperty = new KeyPropertyModel("id", typeof(long)); - PostgresSqlBuilder.BuildDeleteBatchCommand(command, "public", "testcollection", keyProperty, keys); + PostgresSqlBuilder.BuildDeleteBatchCommand(command, schema: null, "testcollection", keyProperty, keys); // Assert Assert.Contains("DELETE", command.CommandText); - Assert.Contains("FROM \"public\".\"testcollection\"", command.CommandText); + Assert.Contains("FROM \"testcollection\"", command.CommandText); Assert.Contains("WHERE \"id\" = ANY($1)", command.CommandText); Assert.NotNull(command.Parameters); Assert.Single(command.Parameters); @@ -417,4 +419,179 @@ public void TestBuildDeleteBatchCommand() // Output this._output.WriteLine(command.CommandText); } + + #region Schema-specified tests + + [Theory] + [InlineData(null, "public")] + [InlineData("myschema", "myschema")] + public void TestBuildDoesTableExistCommand(string? schema, string expectedSchema) + { + using var command = new NpgsqlCommand(); + PostgresSqlBuilder.BuildDoesTableExistCommand(command, schema, "testcollection"); + + Assert.Contains("table_schema = $1", command.CommandText); + Assert.Contains("table_name = $2", command.CommandText); + Assert.Equal(2, command.Parameters.Count); + Assert.Equal(expectedSchema, command.Parameters[0].Value); + Assert.Equal("testcollection", command.Parameters[1].Value); + + this._output.WriteLine(command.CommandText); + } + + [Theory] + [InlineData(null, "public")] + [InlineData("myschema", "myschema")] + public void TestBuildGetTablesCommand(string? schema, string expectedSchema) + { + using var command = new NpgsqlCommand(); + PostgresSqlBuilder.BuildGetTablesCommand(command, schema); + + Assert.Contains("table_schema = $1", command.CommandText); + Assert.Single(command.Parameters); + Assert.Equal(expectedSchema, command.Parameters[0].Value); + + this._output.WriteLine(command.CommandText); + } + + [Fact] + public void TestBuildCreateTableCommand_WithSchema() + { + var recordDefinition = new VectorStoreCollectionDefinition() + { + Properties = [ + new VectorStoreKeyProperty("id", typeof(long)), + new VectorStoreDataProperty("name", typeof(string)), + new VectorStoreVectorProperty("embedding1", typeof(ReadOnlyMemory), 10) { IndexKind = "hnsw" } + ] + }; + + var model = new PostgresModelBuilder().BuildDynamic(recordDefinition, defaultEmbeddingGenerator: null); + + var sql = PostgresSqlBuilder.BuildCreateTableSql(schema: "myschema", "testcollection", model, pgVersion: new Version(18, 0)); + + Assert.Contains("\"myschema\".\"testcollection\"", sql); + + this._output.WriteLine(sql); + } + + [Fact] + public void TestBuildCreateIndexCommand_WithSchema() + { + var sql = PostgresSqlBuilder.BuildCreateIndexSql("myschema", "testcollection", "embedding1", IndexKind.Hnsw, DistanceFunction.CosineDistance, isVector: true, isFullText: false, fullTextLanguage: null, ifNotExists: true); + + Assert.Contains("ON \"myschema\".\"testcollection\"", sql); + + this._output.WriteLine(sql); + } + + [Fact] + public void TestBuildDropTableCommand_WithSchema() + { + using var command = new NpgsqlCommand(); + PostgresSqlBuilder.BuildDropTableCommand(command, schema: "myschema", "testcollection"); + + Assert.Contains("DROP TABLE IF EXISTS \"myschema\".\"testcollection\"", command.CommandText); + + this._output.WriteLine(command.CommandText); + } + + [Fact] + public void TestBuildUpsertCommand_WithSchema() + { + var recordDefinition = new VectorStoreCollectionDefinition() + { + Properties = [ + new VectorStoreKeyProperty("id", typeof(int)), + new VectorStoreDataProperty("name", typeof(string)), + new VectorStoreVectorProperty("embedding1", typeof(ReadOnlyMemory), 10) { IndexKind = "hnsw" } + ] + }; + + var model = new PostgresModelBuilder().BuildDynamic(recordDefinition, defaultEmbeddingGenerator: null); + + var record = new Dictionary + { + ["id"] = 1, + ["name"] = "Test", + ["embedding1"] = new ReadOnlyMemory(s_vector), + }; + + using var batch = new NpgsqlBatch(); + _ = PostgresSqlBuilder.BuildUpsertCommand(batch, schema: "myschema", "testcollection", model, [record], generatedEmbeddings: null); + + var command = batch.BatchCommands[0]; + Assert.Contains("INSERT INTO \"myschema\".\"testcollection\"", command.CommandText); + + this._output.WriteLine(command.CommandText); + } + + [Fact] + public void TestBuildGetCommand_WithSchema() + { + var recordDefinition = new VectorStoreCollectionDefinition() + { + Properties = [ + new VectorStoreKeyProperty("id", typeof(long)), + new VectorStoreDataProperty("name", typeof(string)), + new VectorStoreVectorProperty("embedding1", typeof(ReadOnlyMemory), 10) { IndexKind = "hnsw" } + ] + }; + + var model = new PostgresModelBuilder().BuildDynamic(recordDefinition, defaultEmbeddingGenerator: null); + + using var command = new NpgsqlCommand(); + PostgresSqlBuilder.BuildGetCommand(command, schema: "myschema", "testcollection", model, 123, includeVectors: true); + + Assert.Contains("FROM \"myschema\".\"testcollection\"", command.CommandText); + + this._output.WriteLine(command.CommandText); + } + + [Fact] + public void TestBuildGetBatchCommand_WithSchema() + { + var recordDefinition = new VectorStoreCollectionDefinition() + { + Properties = [ + new VectorStoreKeyProperty("id", typeof(long)), + new VectorStoreDataProperty("name", typeof(string)), + new VectorStoreVectorProperty("embedding1", typeof(ReadOnlyMemory), 10) { IndexKind = "hnsw" } + ] + }; + + var model = new PostgresModelBuilder().BuildDynamic(recordDefinition, defaultEmbeddingGenerator: null); + + using var command = new NpgsqlCommand(); + PostgresSqlBuilder.BuildGetBatchCommand(command, schema: "myschema", "testcollection", model, new List { 1, 2 }, includeVectors: true); + + Assert.Contains("FROM \"myschema\".\"testcollection\"", command.CommandText); + + this._output.WriteLine(command.CommandText); + } + + [Fact] + public void TestBuildDeleteCommand_WithSchema() + { + using var command = new NpgsqlCommand(); + PostgresSqlBuilder.BuildDeleteCommand(command, schema: "myschema", "testcollection", "id", 123); + + Assert.Contains("FROM \"myschema\".\"testcollection\"", command.CommandText); + + this._output.WriteLine(command.CommandText); + } + + [Fact] + public void TestBuildDeleteBatchCommand_WithSchema() + { + using var command = new NpgsqlCommand(); + var keyProperty = new KeyPropertyModel("id", typeof(long)); + PostgresSqlBuilder.BuildDeleteBatchCommand(command, schema: "myschema", "testcollection", keyProperty, new List { 1, 2 }); + + Assert.Contains("FROM \"myschema\".\"testcollection\"", command.CommandText); + + this._output.WriteLine(command.CommandText); + } + + #endregion }