@@ -359,6 +359,7 @@ def __init__(self, local_dc, local_rack, used_hosts_per_remote_dc=0):
359359 self .used_hosts_per_remote_dc = used_hosts_per_remote_dc
360360 self ._live_hosts = {}
361361 self ._dc_live_hosts = {}
362+ self ._remote_hosts = {}
362363 self ._endpoints = []
363364 self ._position = 0
364365 LoadBalancingPolicy .__init__ (self )
@@ -369,78 +370,75 @@ def _rack(self, host):
369370 def _dc (self , host ):
370371 return host .datacenter or self .local_dc
371372
373+ def _refresh_remote_hosts (self ):
374+ remote_hosts = {}
375+ if self .used_hosts_per_remote_dc > 0 :
376+ for datacenter , hosts in self ._dc_live_hosts .items ():
377+ if datacenter != self .local_dc :
378+ remote_hosts .update (dict .fromkeys (hosts [:self .used_hosts_per_remote_dc ]))
379+ self ._remote_hosts = remote_hosts
380+
372381 def populate (self , cluster , hosts ):
373382 for (dc , rack ), rack_hosts in groupby (hosts , lambda host : (self ._dc (host ), self ._rack (host ))):
374383 self ._live_hosts [(dc , rack )] = tuple ({* rack_hosts , * self ._live_hosts .get ((dc , rack ), [])})
375384 for dc , dc_hosts in groupby (hosts , lambda host : self ._dc (host )):
376385 self ._dc_live_hosts [dc ] = tuple ({* dc_hosts , * self ._dc_live_hosts .get (dc , [])})
377386
378387 self ._position = randint (0 , len (hosts ) - 1 ) if hosts else 0
388+ self ._refresh_remote_hosts ()
379389
380390 def distance (self , host ):
381- rack = self ._rack (host )
382391 dc = self ._dc (host )
383- if rack == self .local_rack and dc == self .local_dc :
384- return HostDistance .LOCAL_RACK
385-
386392 if dc == self .local_dc :
393+ if self ._rack (host ) == self .local_rack :
394+ return HostDistance .LOCAL_RACK
387395 return HostDistance .LOCAL
388396
389- if not self .used_hosts_per_remote_dc :
390- return HostDistance .IGNORED
391-
392- dc_hosts = self ._dc_live_hosts .get (dc , ())
393- if not dc_hosts :
394- return HostDistance .IGNORED
395- if host in dc_hosts and dc_hosts .index (host ) < self .used_hosts_per_remote_dc :
397+ if host in self ._remote_hosts :
396398 return HostDistance .REMOTE
397- else :
398- return HostDistance .IGNORED
399+ return HostDistance .IGNORED
399400
400401 def make_query_plan (self , working_keyspace = None , query = None ):
401402 pos = self ._position
402403 self ._position += 1
403404
404405 local_rack_live = self ._live_hosts .get ((self .local_dc , self .local_rack ), ())
405- pos = ( pos % len (local_rack_live )) if local_rack_live else 0
406- # Slice the cyclic iterator to start from pos and include the next len(local_live) elements
407- # This ensures we get exactly one full cycle starting from pos
408- for host in islice (cycle (local_rack_live ), pos , pos + len ( local_rack_live ) ):
409- yield host
406+ length = len (local_rack_live )
407+ if length :
408+ p = pos % length
409+ for host in islice (cycle (local_rack_live ), p , p + length ):
410+ yield host
410411
411- local_live = [host for host in self ._dc_live_hosts .get (self .local_dc , ()) if host .rack != self .local_rack ]
412- pos = (pos % len (local_live )) if local_live else 0
413- for host in islice (cycle (local_live ), pos , pos + len (local_live )):
414- yield host
412+ local_live = self ._dc_live_hosts .get (self .local_dc , ())
413+ local_non_rack = [h for h in local_live if self ._rack (h ) != self .local_rack ]
414+ length = len (local_non_rack )
415+ if length :
416+ p = pos % length
417+ for host in islice (cycle (local_non_rack ), p , p + length ):
418+ yield host
415419
416- # the dict can change, so get candidate DCs iterating over keys of a copy
417- for dc , remote_live in self ._dc_live_hosts .copy ().items ():
418- if dc != self .local_dc :
419- for host in remote_live [:self .used_hosts_per_remote_dc ]:
420- yield host
420+ for host in self ._remote_hosts :
421+ yield host
421422
422423 def on_up (self , host ):
423424 dc = self ._dc (host )
424425 rack = self ._rack (host )
425426 with self ._hosts_lock :
426- current_rack_hosts = self ._live_hosts .get ((dc , rack ), ())
427- if host not in current_rack_hosts :
428- self ._live_hosts [(dc , rack )] = current_rack_hosts + (host , )
429427 current_dc_hosts = self ._dc_live_hosts .get (dc , ())
430428 if host not in current_dc_hosts :
431429 self ._dc_live_hosts [dc ] = current_dc_hosts + (host , )
432430
431+ if dc != self .local_dc :
432+ self ._refresh_remote_hosts ()
433+
434+ current_rack_hosts = self ._live_hosts .get ((dc , rack ), ())
435+ if host not in current_rack_hosts :
436+ self ._live_hosts [(dc , rack )] = current_rack_hosts + (host , )
437+
433438 def on_down (self , host ):
434439 dc = self ._dc (host )
435440 rack = self ._rack (host )
436441 with self ._hosts_lock :
437- current_rack_hosts = self ._live_hosts .get ((dc , rack ), ())
438- if host in current_rack_hosts :
439- hosts = tuple (h for h in current_rack_hosts if h != host )
440- if hosts :
441- self ._live_hosts [(dc , rack )] = hosts
442- else :
443- del self ._live_hosts [(dc , rack )]
444442 current_dc_hosts = self ._dc_live_hosts .get (dc , ())
445443 if host in current_dc_hosts :
446444 hosts = tuple (h for h in current_dc_hosts if h != host )
@@ -449,6 +447,20 @@ def on_down(self, host):
449447 else :
450448 del self ._dc_live_hosts [dc ]
451449
450+ if dc != self .local_dc :
451+ self ._refresh_remote_hosts ()
452+
453+ if dc != self .local_dc :
454+ self ._refresh_remote_hosts ()
455+
456+ current_rack_hosts = self ._live_hosts .get ((dc , rack ), ())
457+ if host in current_rack_hosts :
458+ hosts = tuple (h for h in current_rack_hosts if h != host )
459+ if hosts :
460+ self ._live_hosts [(dc , rack )] = hosts
461+ else :
462+ del self ._live_hosts [(dc , rack )]
463+
452464 def on_add (self , host ):
453465 self .on_up (host )
454466
0 commit comments