77 "os/signal"
88 "runtime"
99 "sort"
10- "strings"
11- "sync"
1210 "syscall"
1311 "time"
1412
@@ -151,24 +149,38 @@ func (r *MigrateFromQdrantCmd) prepareTargetCollection(ctx context.Context, sour
151149 fmt .Print ("\n " )
152150 pterm .Info .Printfln ("Target collection '%s' already exists. Skipping creation." , targetCollection )
153151 } else {
154- params := sourceCollectionInfo .Config .GetParams ()
152+ config := sourceCollectionInfo .GetConfig ()
153+ params := config .GetParams ()
155154 if err := targetClient .CreateCollection (ctx , & qdrant.CreateCollection {
156155 CollectionName : targetCollection ,
157- HnswConfig : sourceCollectionInfo .Config .GetHnswConfig (),
158- WalConfig : sourceCollectionInfo .Config .GetWalConfig (),
159- OptimizersConfig : sourceCollectionInfo .Config .GetOptimizerConfig (),
160- ShardNumber : & params .ShardNumber ,
161- OnDiskPayload : & params .OnDiskPayload ,
162- VectorsConfig : params .VectorsConfig ,
163- ReplicationFactor : params .ReplicationFactor ,
164- WriteConsistencyFactor : params .WriteConsistencyFactor ,
165- QuantizationConfig : sourceCollectionInfo .Config .GetQuantizationConfig (),
166- ShardingMethod : params .ShardingMethod ,
167- SparseVectorsConfig : params .SparseVectorsConfig ,
168- StrictModeConfig : sourceCollectionInfo .Config .GetStrictModeConfig (),
156+ HnswConfig : config .GetHnswConfig (),
157+ WalConfig : config .GetWalConfig (),
158+ OptimizersConfig : config .GetOptimizerConfig (),
159+ ShardNumber : qdrant .PtrOf (params .GetShardNumber ()),
160+ OnDiskPayload : qdrant .PtrOf (params .GetOnDiskPayload ()),
161+ VectorsConfig : params .GetVectorsConfig (),
162+ ReplicationFactor : qdrant .PtrOf (params .GetReplicationFactor ()),
163+ WriteConsistencyFactor : qdrant .PtrOf (params .GetWriteConsistencyFactor ()),
164+ QuantizationConfig : config .GetQuantizationConfig (),
165+ ShardingMethod : params .GetShardingMethod ().Enum (),
166+ SparseVectorsConfig : params .GetSparseVectorsConfig (),
167+ StrictModeConfig : config .GetStrictModeConfig (),
168+ Metadata : config .GetMetadata (),
169169 }); err != nil {
170170 return fmt .Errorf ("failed to create target collection: %w" , err )
171171 }
172+
173+ if params .GetShardingMethod () == qdrant .ShardingMethod_Custom {
174+ shardKeys , err := sourceClient .ListShardKeys (ctx , sourceCollection )
175+ if err != nil {
176+ return fmt .Errorf ("failed to list shard keys: %w" , err )
177+ }
178+ for _ , shardKey := range shardKeys {
179+ if err := targetClient .CreateShardKey (ctx , targetCollection , & qdrant.CreateShardKey {ShardKey : shardKey .GetKey ()}); err != nil {
180+ return fmt .Errorf ("failed to create shard key: %w" , err )
181+ }
182+ }
183+ }
172184 }
173185 }
174186
@@ -322,7 +334,7 @@ func (r *MigrateFromQdrantCmd) samplePointIDs(ctx context.Context, client *qdran
322334
323335// processBatch handles the upserting of a batch of points to the target collection.
324336// It groups points by shard key and targets each upsert at the matching shard key (keys must already exist on the target; they are not created here), and retries on transient errors.
325- func (r * MigrateFromQdrantCmd ) processBatch (ctx context.Context , points []* qdrant.RetrievedPoint , targetClient * qdrant.Client , targetCollection string , shardKeys * sync. Map , wait bool ) error {
337+ func (r * MigrateFromQdrantCmd ) processBatch (ctx context.Context , points []* qdrant.RetrievedPoint , targetClient * qdrant.Client , targetCollection string , wait bool ) error {
326338 // Group points by their shard key.
327339 byShardKey := make (map [string ][]* qdrant.PointStruct )
328340 shardKeyObjs := make (map [string ]* qdrant.ShardKey )
@@ -353,14 +365,6 @@ func (r *MigrateFromQdrantCmd) processBatch(ctx context.Context, points []*qdran
353365 Wait : qdrant .PtrOf (wait ),
354366 }
355367 if key != "" {
356- // If the shard key is new, create it on the target collection.
357- if _ , ok := shardKeys .Load (key ); ! ok {
358- err := targetClient .CreateShardKey (ctx , targetCollection , & qdrant.CreateShardKey {ShardKey : shardKeyObjs [key ]})
359- if err != nil && ! strings .Contains (err .Error (), "already exists" ) {
360- return fmt .Errorf ("failed to create shard key %s: %w" , key , err )
361- }
362- shardKeys .Store (key , true )
363- }
364368 // Specify the shard key for the upsert request.
365369 req .ShardKeySelector = & qdrant.ShardKeySelector {ShardKeys : []* qdrant.ShardKey {shardKeyObjs [key ]}}
366370 }
@@ -395,7 +399,6 @@ func (r *MigrateFromQdrantCmd) migrateDataSequential(ctx context.Context, source
395399
396400 bar , _ := pterm .DefaultProgressbar .WithTotal (int (sourcePointCount )).Start ()
397401 displayMigrationProgress (bar , count )
398- shardKeys := & sync.Map {}
399402
400403 for {
401404 // Scroll through points from the source collection in batches.
@@ -411,7 +414,7 @@ func (r *MigrateFromQdrantCmd) migrateDataSequential(ctx context.Context, source
411414 }
412415
413416 points := resp .GetResult ()
414- if err := r .processBatch (ctx , points , targetClient , targetCollection , shardKeys , true ); err != nil {
417+ if err := r .processBatch (ctx , points , targetClient , targetCollection , true ); err != nil {
415418 return err
416419 }
417420
@@ -478,15 +481,14 @@ func (r *MigrateFromQdrantCmd) migrateDataParallel(ctx context.Context, sourceCl
478481 displayMigrationProgress (bar , totalProcessed )
479482
480483 // Use a semaphore to limit the number of concurrent workers.
481- shardKeys := & sync.Map {}
482484 errs := make (chan error , len (ranges ))
483485 sem := make (chan struct {}, r .NumWorkers )
484486
485487 // Start a goroutine for each range.
486488 for _ , rg := range ranges {
487489 sem <- struct {}{}
488490 go func (rg rangeSpec ) {
489- errs <- r .migrateRange (ctx , sourceCollection , targetCollection , sourceClient , targetClient , rg , shardKeys , bar )
491+ errs <- r .migrateRange (ctx , sourceCollection , targetCollection , sourceClient , targetClient , rg , bar )
490492 <- sem
491493 }(rg )
492494 }
@@ -507,7 +509,7 @@ func (r *MigrateFromQdrantCmd) migrateDataParallel(ctx context.Context, sourceCl
507509
508510// migrateRange is the function executed by each worker in parallel migration.
509511// It scrolls through a specific range of points and upserts them to the target.
510- func (r * MigrateFromQdrantCmd ) migrateRange (ctx context.Context , sourceCollection , targetCollection string , sourceClient , targetClient * qdrant.Client , rg rangeSpec , shardKeys * sync. Map , bar * pterm.ProgressbarPrinter ) error {
512+ func (r * MigrateFromQdrantCmd ) migrateRange (ctx context.Context , sourceCollection , targetCollection string , sourceClient , targetClient * qdrant.Client , rg rangeSpec , bar * pterm.ProgressbarPrinter ) error {
511513 offsetKey := fmt .Sprintf ("%s-workers-%d-range-%d" , sourceCollection , r .NumWorkers , rg .id )
512514 offset := rg .start
513515 var count uint64
@@ -542,7 +544,7 @@ func (r *MigrateFromQdrantCmd) migrateRange(ctx context.Context, sourceCollectio
542544 }
543545 }
544546
545- if err := r .processBatch (ctx , points , targetClient , targetCollection , shardKeys , false ); err != nil {
547+ if err := r .processBatch (ctx , points , targetClient , targetCollection , false ); err != nil {
546548 return err
547549 }
548550
0 commit comments