Skip to content

Commit 7244a28

Browse files
committed
Implement latest SQL Server approximate vector search feature
Part of #36384
1 parent 6abde55 commit 7244a28

8 files changed

Lines changed: 61 additions & 42 deletions

File tree

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,3 +342,6 @@ ASALocalRun/
342342
# Local History for Visual Studio
343343
.localhistory/
344344
full.targets.txt
345+
346+
# Language Server cache
347+
*.lscache

Directory.Packages.props

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
<PackageVersion Include="Microsoft.Azure.Cosmos" Version="3.58.0" />
3535
<!-- SQL Server dependencies -->
3636
<PackageVersion Include="Microsoft.Data.SqlClient" Version="7.0.0" />
37+
<PackageVersion Include="Microsoft.Data.SqlClient.Extensions.Azure" Version="1.0.0" />
3738
<PackageVersion Include="Microsoft.SqlServer.Types" Version="170.1000.7" />
3839
<!-- external dependencies -->
3940
<PackageVersion Include="Castle.Core" Version="5.2.1" />

src/EFCore.SqlServer/Extensions/SqlServerQueryableExtensions.cs

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,16 @@ public static class SqlServerQueryableExtensions
2828
/// An ANN (Approximate Nearest Neighbor) index is used only if a matching ANN index, with the same metric and on the same column,
2929
/// is found. If there are no compatible ANN indexes, a warning is raised and the KNN (k-Nearest Neighbor) algorithm is used.
3030
/// </param>
31-
/// <param name="topN">The maximum number of similar vectors that must be returned. It must be a positive integer.</param>
31+
/// <remarks>
32+
/// Compose the returned query with <c>OrderBy(r => r.Distance)</c> and <c>Take(...)</c> to limit the results as required
33+
/// for approximate vector search. For example:
34+
/// <code>
35+
/// var results = context.Set&lt;Blog&gt;()
36+
/// .VectorSearch(b => b.Embedding, embedding, "cosine")
37+
/// .OrderBy(r => r.Distance)
38+
/// .Take(10);
39+
/// </code>
40+
/// </remarks>
3241
/// <seealso href="https://learn.microsoft.com/sql/t-sql/functions/vector-search-transact-sql">
3342
/// SQL Server documentation for <c>VECTOR_SEARCH()</c>.
3443
/// </seealso>
@@ -38,8 +47,7 @@ public static IQueryable<VectorSearchResult<T>> VectorSearch<T, TVector>(
3847
this DbSet<T> source,
3948
Expression<Func<T, TVector>> vectorPropertySelector,
4049
TVector similarTo,
41-
[NotParameterized] string metric,
42-
int topN)
50+
[NotParameterized] string metric)
4351
where T : class
4452
where TVector : unmanaged
4553
{
@@ -50,12 +58,11 @@ public static IQueryable<VectorSearchResult<T>> VectorSearch<T, TVector>(
5058
? queryableSource.Provider.CreateQuery<VectorSearchResult<T>>(
5159
Expression.Call(
5260
// Note that the method used is the one below, accepting IQueryable<T>, not DbSet<T>
53-
method: new Func<IQueryable<T>, Expression<Func<T, TVector>>, TVector, string, int, IQueryable<VectorSearchResult<T>>>(VectorSearch).Method,
61+
method: new Func<IQueryable<T>, Expression<Func<T, TVector>>, TVector, string, IQueryable<VectorSearchResult<T>>>(VectorSearch).Method,
5462
root,
5563
Expression.Quote(vectorPropertySelector),
5664
Expression.Constant(similarTo),
57-
Expression.Constant(metric),
58-
Expression.Constant(topN)))
65+
Expression.Constant(metric)))
5966
: throw new InvalidOperationException(CoreStrings.FunctionOnNonEfLinqProvider(nameof(VectorSearch)));
6067
}
6168

@@ -67,8 +74,7 @@ private static IQueryable<VectorSearchResult<T>> VectorSearch<T, TVector>(
6774
this IQueryable<T> source,
6875
Expression<Func<T, TVector>> vectorPropertySelector,
6976
TVector similarTo,
70-
[NotParameterized] string metric,
71-
int topN)
77+
[NotParameterized] string metric)
7278
where T : class
7379
where TVector : unmanaged
7480
=> throw new UnreachableException();

src/EFCore.SqlServer/Extensions/VectorSearchResult.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ namespace Microsoft.EntityFrameworkCore;
77

88
/// <summary>
99
/// Represents the results from a call to
10-
/// <see cref="SqlServerQueryableExtensions.VectorSearch{T, TVector}(DbSet{T}, Expression{Func{T, TVector}}, TVector, string, int)" />.
10+
/// <see cref="SqlServerQueryableExtensions.VectorSearch{T, TVector}(DbSet{T}, Expression{Func{T, TVector}}, TVector, string)" />.
1111
/// </summary>
1212
[Experimental(EFDiagnostics.SqlServerVectorSearch)]
1313
public readonly struct VectorSearchResult<T>(T value, double distance)

src/EFCore.SqlServer/Query/Internal/SqlServerQuerySqlGenerator.cs

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -155,16 +155,14 @@ protected override Expression VisitTableValuedFunction(TableValuedFunctionExpres
155155
TableExpression table,
156156
ColumnExpression column,
157157
SqlExpression similarTo,
158-
SqlConstantExpression { Value: string } metric,
159-
SqlExpression topN
158+
SqlConstantExpression { Value: string } metric
160159
]
161160
}:
162161
// VECTOR_SEARCH(
163162
// TABLE = [Articles] AS t,
164163
// COLUMN = [Vector],
165164
// SIMILAR_TO = @qv,
166-
// METRIC = 'Cosine',
167-
// TOP_N = 3
165+
// METRIC = 'Cosine'
168166
// )
169167
Sql.AppendLine("VECTOR_SEARCH(");
170168

@@ -185,10 +183,6 @@ SqlExpression topN
185183

186184
Sql.Append("METRIC = ");
187185
Visit(metric);
188-
Sql.AppendLine(",");
189-
190-
Sql.Append("TOP_N = ");
191-
Visit(topN);
192186
Sql.AppendLine();
193187
}
194188

@@ -558,6 +552,13 @@ protected override void GenerateTop(SelectExpression selectExpression)
558552
Visit(selectExpression.Limit);
559553

560554
Sql.Append(") ");
555+
556+
// When performing approximate vector search with VECTOR_SEARCH(), SQL Server requires adding
557+
// WITH APPROXIMATE: https://learn.microsoft.com/sql/t-sql/functions/vector-search-transact-sql
558+
if (selectExpression.Tables.Any(t => t.UnwrapJoin() is TableValuedFunctionExpression { Name: "VECTOR_SEARCH" })
559+
{
560+
Sql.Append("WITH APPROXIMATE ");
561+
}
561562
}
562563

563564
_withinTable = parentWithinTable;

src/EFCore.SqlServer/Query/Internal/SqlServerQueryableMethodTranslatingExpressionVisitor.cs

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,7 @@ when methodCallExpression.Arguments is
9191
_, // source, translated above
9292
UnaryExpression { NodeType: ExpressionType.Quote, Operand: LambdaExpression vectorPropertySelector },
9393
var similarTo,
94-
var metric,
95-
var topN
94+
var metric
9695
]
9796
&& source is
9897
{
@@ -113,8 +112,7 @@ var topN
113112
}
114113

115114
if (TranslateExpression(similarTo) is not { } translatedSimilarTo
116-
|| TranslateExpression(metric, applyDefaultTypeMapping: false) is not { } translatedMetric
117-
|| TranslateExpression(topN) is not { } translatedTopN)
115+
|| TranslateExpression(metric, applyDefaultTypeMapping: false) is not { } translatedMetric)
118116
{
119117
return QueryCompilationContext.NotTranslatedExpression;
120118
}
@@ -135,8 +133,7 @@ var topN
135133
// as required by SQL Server)
136134
vectorColumn,
137135
translatedSimilarTo,
138-
translatedMetric,
139-
translatedTopN
136+
translatedMetric
140137
]);
141138

142139
// We have the VECTOR_SEARCH() function call. Modify the SelectExpression and shaper to use it and project

test/EFCore.SqlServer.FunctionalTests/EFCore.SqlServer.FunctionalTests.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,5 +77,6 @@
7777
<ItemGroup>
7878
<PackageReference Include="Microsoft.Extensions.Configuration.EnvironmentVariables" />
7979
<PackageReference Include="Microsoft.Extensions.Configuration.Json" />
80+
<PackageReference Include="Microsoft.Data.SqlClient.Extensions.Azure" />
8081
</ItemGroup>
8182
</Project>

test/EFCore.SqlServer.FunctionalTests/Query/Translations/VectorTranslationsSqlServerTest.cs

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,9 @@ ORDER BY VECTOR_DISTANCE('cosine', [v].[Vector], CAST('[1,2,100]' AS VECTOR(3)))
6565
""");
6666
}
6767

68+
// The latest vector index version (required for VECTOR_SEARCH) is only available on Azure SQL (#36384).
6869
[ConditionalFact]
70+
[SqlServerCondition(SqlServerCondition.IsAzureSql)]
6971
[Experimental("EF9105")]
7072
public async Task VectorSearch_project_entity_and_distance()
7173
{
@@ -74,28 +76,32 @@ public async Task VectorSearch_project_entity_and_distance()
7476
var vector = new SqlVector<float>(new float[] { 1, 2, 100 });
7577

7678
var results = await ctx.VectorEntities
77-
.VectorSearch(e => e.Vector, similarTo: vector, "cosine", topN: 1)
79+
.VectorSearch(e => e.Vector, similarTo: vector, "cosine")
80+
.OrderBy(e => e.Distance)
81+
.Take(1)
7882
.ToListAsync();
7983

8084
Assert.Equal(2, results.Single().Value.Id);
8185

8286
AssertSql(
8387
"""
84-
@p='Microsoft.Data.SqlTypes.SqlVector`1[System.Single]' (Size = 20) (DbType = Binary)
8588
@p1='1'
89+
@p='Microsoft.Data.SqlTypes.SqlVector`1[System.Single]' (Size = 20) (DbType = Binary)
8690
87-
SELECT [v].[Id], [v0].[Distance]
91+
SELECT TOP(@p1) WITH APPROXIMATE [v].[Id], [v0].[Distance]
8892
FROM VECTOR_SEARCH(
8993
TABLE = [VectorEntities] AS [v],
9094
COLUMN = [Vector],
9195
SIMILAR_TO = @p,
92-
METRIC = 'cosine',
93-
TOP_N = @p1
96+
METRIC = 'cosine'
9497
) AS [v0]
98+
ORDER BY [v0].[Distance]
9599
""");
96100
}
97101

102+
// The latest vector index version (required for VECTOR_SEARCH) is only available on Azure SQL (#36384).
98103
[ConditionalFact]
104+
[SqlServerCondition(SqlServerCondition.IsAzureSql)]
99105
[Experimental("EF9105")]
100106
public async Task VectorSearch_project_entity_only_with_distance_filter_and_ordering()
101107
{
@@ -104,10 +110,11 @@ public async Task VectorSearch_project_entity_only_with_distance_filter_and_orde
104110
var vector = new SqlVector<float>(new float[] { 1, 2, 100 });
105111

106112
var results = await ctx.VectorEntities
107-
.VectorSearch(e => e.Vector, similarTo: vector, "cosine", topN: 3)
113+
.VectorSearch(e => e.Vector, similarTo: vector, "cosine")
108114
.Where(e => e.Distance < 0.01)
109115
.OrderBy(e => e.Distance)
110116
.Select(e => e.Value)
117+
.Take(3)
111118
.ToListAsync();
112119

113120
Assert.Collection(
@@ -117,16 +124,15 @@ public async Task VectorSearch_project_entity_only_with_distance_filter_and_orde
117124

118125
AssertSql(
119126
"""
120-
@p='Microsoft.Data.SqlTypes.SqlVector`1[System.Single]' (Size = 20) (DbType = Binary)
121127
@p1='3'
128+
@p='Microsoft.Data.SqlTypes.SqlVector`1[System.Single]' (Size = 20) (DbType = Binary)
122129
123-
SELECT [v].[Id]
130+
SELECT TOP(@p1) WITH APPROXIMATE [v].[Id]
124131
FROM VECTOR_SEARCH(
125132
TABLE = [VectorEntities] AS [v],
126133
COLUMN = [Vector],
127134
SIMILAR_TO = @p,
128-
METRIC = 'cosine',
129-
TOP_N = @p1
135+
METRIC = 'cosine'
130136
) AS [v0]
131137
WHERE [v0].[Distance] < 0.01E0
132138
ORDER BY [v0].[Distance]
@@ -167,23 +173,27 @@ public class VectorQueryContext(DbContextOptions options) : PoolableDbContext(op
167173

168174
public static async Task SeedAsync(VectorQueryContext context)
169175
{
170-
var vectorEntities = new VectorEntity[]
171-
{
172-
new() { Id = 1, Vector = new SqlVector<float>(new float[] { 1, 2, 3 }) },
173-
new() { Id = 2, Vector = new SqlVector<float>(new float[] { 1, 2, 100 }) },
174-
new() { Id = 3, Vector = new SqlVector<float>(new float[] { 1, 2, 1000 }) }
175-
};
176+
// SQL Server vector indexes require at least 100 rows.
177+
var vectorEntities = Enumerable.Range(1, 100).Select(
178+
i => new VectorEntity
179+
{
180+
Id = i,
181+
Vector = new SqlVector<float>(new float[] { i * 0.01f, i * 0.02f, i * 0.03f })
182+
}).ToList();
183+
184+
// Override specific rows we use in test assertions
185+
vectorEntities[0] = new VectorEntity { Id = 1, Vector = new SqlVector<float>(new float[] { 1, 2, 3 }) };
186+
vectorEntities[1] = new VectorEntity { Id = 2, Vector = new SqlVector<float>(new float[] { 1, 2, 100 }) };
187+
vectorEntities[2] = new VectorEntity { Id = 3, Vector = new SqlVector<float>(new float[] { 1, 2, 1000 }) };
176188

177189
context.VectorEntities.AddRange(vectorEntities);
178190
await context.SaveChangesAsync();
179191

180-
// TODO (#36384): Remove this once it's out of preview
181192
await context.Database.ExecuteSqlAsync($"ALTER DATABASE SCOPED CONFIGURATION SET PREVIEW_FEATURES = ON");
182193

183194
await context.Database.ExecuteSqlAsync($"""
184195
CREATE VECTOR INDEX vec_idx ON VectorEntities(Vector)
185-
WITH (METRIC = 'Cosine', TYPE = 'DiskANN')
186-
ON [PRIMARY];
196+
WITH (METRIC = 'Cosine', TYPE = 'DiskANN');
187197
""");
188198
}
189199
}

0 commit comments

Comments
 (0)