
Commit 0710650

refactor: Use the new listShardKeys() API
Signed-off-by: Anush008 <mail@anush.sh>
1 parent e97acca commit 0710650

2 files changed

Lines changed: 33 additions & 31 deletions


cmd/migrate_from_qdrant.go

Lines changed: 32 additions & 30 deletions
@@ -7,8 +7,6 @@ import (
 	"os/signal"
 	"runtime"
 	"sort"
-	"strings"
-	"sync"
 	"syscall"
 	"time"

@@ -151,24 +149,38 @@ func (r *MigrateFromQdrantCmd) prepareTargetCollection(ctx context.Context, sour
 		fmt.Print("\n")
 		pterm.Info.Printfln("Target collection '%s' already exists. Skipping creation.", targetCollection)
 	} else {
-		params := sourceCollectionInfo.Config.GetParams()
+		config := sourceCollectionInfo.GetConfig()
+		params := config.GetParams()
 		if err := targetClient.CreateCollection(ctx, &qdrant.CreateCollection{
 			CollectionName:         targetCollection,
-			HnswConfig:             sourceCollectionInfo.Config.GetHnswConfig(),
-			WalConfig:              sourceCollectionInfo.Config.GetWalConfig(),
-			OptimizersConfig:       sourceCollectionInfo.Config.GetOptimizerConfig(),
-			ShardNumber:            &params.ShardNumber,
-			OnDiskPayload:          &params.OnDiskPayload,
-			VectorsConfig:          params.VectorsConfig,
-			ReplicationFactor:      params.ReplicationFactor,
-			WriteConsistencyFactor: params.WriteConsistencyFactor,
-			QuantizationConfig:     sourceCollectionInfo.Config.GetQuantizationConfig(),
-			ShardingMethod:         params.ShardingMethod,
-			SparseVectorsConfig:    params.SparseVectorsConfig,
-			StrictModeConfig:       sourceCollectionInfo.Config.GetStrictModeConfig(),
+			HnswConfig:             config.GetHnswConfig(),
+			WalConfig:              config.GetWalConfig(),
+			OptimizersConfig:       config.GetOptimizerConfig(),
+			ShardNumber:            qdrant.PtrOf(params.GetShardNumber()),
+			OnDiskPayload:          qdrant.PtrOf(params.GetOnDiskPayload()),
+			VectorsConfig:          params.GetVectorsConfig(),
+			ReplicationFactor:      qdrant.PtrOf(params.GetReplicationFactor()),
+			WriteConsistencyFactor: qdrant.PtrOf(params.GetWriteConsistencyFactor()),
+			QuantizationConfig:     config.GetQuantizationConfig(),
+			ShardingMethod:         params.GetShardingMethod().Enum(),
+			SparseVectorsConfig:    params.GetSparseVectorsConfig(),
+			StrictModeConfig:       config.GetStrictModeConfig(),
+			Metadata:               config.GetMetadata(),
 		}); err != nil {
 			return fmt.Errorf("failed to create target collection: %w", err)
 		}
+
+		if params.GetShardingMethod() == qdrant.ShardingMethod_Custom {
+			shardKeys, err := sourceClient.ListShardKeys(ctx, sourceCollection)
+			if err != nil {
+				return fmt.Errorf("failed to list shard keys: %w", err)
+			}
+			for _, shardKey := range shardKeys {
+				if err := targetClient.CreateShardKey(ctx, targetCollection, &qdrant.CreateShardKey{ShardKey: shardKey.GetKey()}); err != nil {
+					return fmt.Errorf("failed to create shard key: %w", err)
+				}
+			}
+		}
 	}
 }
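
Note: custom shard keys of the source collection are now mirrored onto the target once, right after the collection is created, instead of being created lazily while batches are upserted. As a rough sketch of how the same listShardKeys() API could also be used to sanity-check the result afterwards, here is a hypothetical helper (verifyShardKeys is not part of this commit; it only reuses the ListShardKeys call shown above):

// Hypothetical helper, not part of this commit: compare the number of custom
// shard keys on the source and target collections after migration.
func verifyShardKeys(ctx context.Context, sourceClient, targetClient *qdrant.Client, sourceCollection, targetCollection string) error {
	srcKeys, err := sourceClient.ListShardKeys(ctx, sourceCollection)
	if err != nil {
		return fmt.Errorf("failed to list source shard keys: %w", err)
	}
	dstKeys, err := targetClient.ListShardKeys(ctx, targetCollection)
	if err != nil {
		return fmt.Errorf("failed to list target shard keys: %w", err)
	}
	if len(srcKeys) != len(dstKeys) {
		return fmt.Errorf("shard key count mismatch: source has %d, target has %d", len(srcKeys), len(dstKeys))
	}
	return nil
}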

@@ -322,7 +334,7 @@ func (r *MigrateFromQdrantCmd) samplePointIDs(ctx context.Context, client *qdran

 // processBatch handles the upserting of a batch of points to the target collection.
 // It deals with sharding by creating shard keys if they don't exist and retries on transient errors.
-func (r *MigrateFromQdrantCmd) processBatch(ctx context.Context, points []*qdrant.RetrievedPoint, targetClient *qdrant.Client, targetCollection string, shardKeys *sync.Map, wait bool) error {
+func (r *MigrateFromQdrantCmd) processBatch(ctx context.Context, points []*qdrant.RetrievedPoint, targetClient *qdrant.Client, targetCollection string, wait bool) error {
 	// Group points by their shard key.
 	byShardKey := make(map[string][]*qdrant.PointStruct)
 	shardKeyObjs := make(map[string]*qdrant.ShardKey)
@@ -353,14 +365,6 @@ func (r *MigrateFromQdrantCmd) processBatch(ctx context.Context, points []*qdran
 			Wait:           qdrant.PtrOf(wait),
 		}
 		if key != "" {
-			// If the shard key is new, create it on the target collection.
-			if _, ok := shardKeys.Load(key); !ok {
-				err := targetClient.CreateShardKey(ctx, targetCollection, &qdrant.CreateShardKey{ShardKey: shardKeyObjs[key]})
-				if err != nil && !strings.Contains(err.Error(), "already exists") {
-					return fmt.Errorf("failed to create shard key %s: %w", key, err)
-				}
-				shardKeys.Store(key, true)
-			}
 			// Specify the shard key for the upsert request.
 			req.ShardKeySelector = &qdrant.ShardKeySelector{ShardKeys: []*qdrant.ShardKey{shardKeyObjs[key]}}
 		}
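
With shard keys created up front in prepareTargetCollection, processBatch only has to select the right shard key on each upsert. A minimal sketch of what one grouped upsert now amounts to, assuming the request type is *qdrant.UpsertPoints and the client's Upsert method; the helper name upsertGroup is hypothetical and the real code's retry handling is omitted:

// Hypothetical fragment mirroring processBatch: the shard key is only selected
// here; it was already copied from the source in prepareTargetCollection.
func upsertGroup(ctx context.Context, client *qdrant.Client, collection, key string, points []*qdrant.PointStruct, shardKey *qdrant.ShardKey, wait bool) error {
	req := &qdrant.UpsertPoints{
		CollectionName: collection,
		Points:         points,
		Wait:           qdrant.PtrOf(wait),
	}
	if key != "" {
		req.ShardKeySelector = &qdrant.ShardKeySelector{ShardKeys: []*qdrant.ShardKey{shardKey}}
	}
	if _, err := client.Upsert(ctx, req); err != nil {
		return fmt.Errorf("failed to upsert batch for shard key %q: %w", key, err)
	}
	return nil
}
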
@@ -395,7 +399,6 @@ func (r *MigrateFromQdrantCmd) migrateDataSequential(ctx context.Context, source

 	bar, _ := pterm.DefaultProgressbar.WithTotal(int(sourcePointCount)).Start()
 	displayMigrationProgress(bar, count)
-	shardKeys := &sync.Map{}

 	for {
 		// Scroll through points from the source collection in batches.
@@ -411,7 +414,7 @@ func (r *MigrateFromQdrantCmd) migrateDataSequential(ctx context.Context, source
 		}

 		points := resp.GetResult()
-		if err := r.processBatch(ctx, points, targetClient, targetCollection, shardKeys, true); err != nil {
+		if err := r.processBatch(ctx, points, targetClient, targetCollection, true); err != nil {
 			return err
 		}

@@ -478,15 +481,14 @@ func (r *MigrateFromQdrantCmd) migrateDataParallel(ctx context.Context, sourceCl
 	displayMigrationProgress(bar, totalProcessed)

 	// Use a semaphore to limit the number of concurrent workers.
-	shardKeys := &sync.Map{}
 	errs := make(chan error, len(ranges))
 	sem := make(chan struct{}, r.NumWorkers)

 	// Start a goroutine for each range.
 	for _, rg := range ranges {
 		sem <- struct{}{}
 		go func(rg rangeSpec) {
-			errs <- r.migrateRange(ctx, sourceCollection, targetCollection, sourceClient, targetClient, rg, shardKeys, bar)
+			errs <- r.migrateRange(ctx, sourceCollection, targetCollection, sourceClient, targetClient, rg, bar)
 			<-sem
 		}(rg)
 	}
@@ -507,7 +509,7 @@ func (r *MigrateFromQdrantCmd) migrateDataParallel(ctx context.Context, sourceCl

 // migrateRange is the function executed by each worker in parallel migration.
 // It scrolls through a specific range of points and upserts them to the target.
-func (r *MigrateFromQdrantCmd) migrateRange(ctx context.Context, sourceCollection, targetCollection string, sourceClient, targetClient *qdrant.Client, rg rangeSpec, shardKeys *sync.Map, bar *pterm.ProgressbarPrinter) error {
+func (r *MigrateFromQdrantCmd) migrateRange(ctx context.Context, sourceCollection, targetCollection string, sourceClient, targetClient *qdrant.Client, rg rangeSpec, bar *pterm.ProgressbarPrinter) error {
 	offsetKey := fmt.Sprintf("%s-workers-%d-range-%d", sourceCollection, r.NumWorkers, rg.id)
 	offset := rg.start
 	var count uint64
@@ -542,7 +544,7 @@ func (r *MigrateFromQdrantCmd) migrateRange(ctx context.Context, sourceCollectio
 		}
 	}

-	if err := r.processBatch(ctx, points, targetClient, targetCollection, shardKeys, false); err != nil {
+	if err := r.processBatch(ctx, points, targetClient, targetCollection, false); err != nil {
 		return err
 	}

integration_tests/image_test.go

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@ import (
 //nolint:unparam
 func qdrantContainer(ctx context.Context, t *testing.T, apiKey string) testcontainers.Container {
 	req := testcontainers.ContainerRequest{
-		Image:        "qdrant/qdrant:v1.16.1",
+		Image:        "qdrant/qdrant:v1.17.0",
 		ExposedPorts: []string{"6333/tcp", "6334/tcp"},
 		Env: map[string]string{
 			"QDRANT__CLUSTER__ENABLED": "true",
