From ad29b46d9ec4d88baba21b39ca75808ee8d484f0 Mon Sep 17 00:00:00 2001 From: Mikita Hradovich Date: Thu, 29 Jan 2026 19:20:45 +0100 Subject: [PATCH 1/4] =?UTF-8?q?feat:=20Implement=20LWT=20replica-only=20ro?= =?UTF-8?q?uting=20with=20local=20DC=20prioritization=20in=20`TokenAwarePo?= =?UTF-8?q?licy`.=20=F0=9F=8E=9F=EF=B8=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit feat: Enhance LWT query routing by prioritizing local replicas and implementing fallback to child policy feat: Refactor LWT host iterator to preserve replica order and improve host filtering Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Dmitry Kropachev --- README.md | 5 +- .../datastax/driver/core/RequestHandler.java | 45 --- .../core/policies/TokenAwarePolicy.java | 324 +++++++++++++---- .../core/policies/TokenAwarePolicyTest.java | 326 +++++++++++++++++- 4 files changed, 573 insertions(+), 127 deletions(-) diff --git a/README.md b/README.md index bc4ba5ed812..928a48fccf2 100644 --- a/README.md +++ b/README.md @@ -17,8 +17,11 @@ The Scylla Java Driver is a fork from [DataStax Java Driver](https://github.com/ **Features:** * Like all Scylla Drivers, the Scylla Java Driver is **Shard Aware** and contains extensions for a `tokenAwareHostPolicy`. - Using this policy, the driver can select a connection to a particular shard based on the shard’s token. + Using this policy, the driver can select a connection to a particular shard based on the shard's token. As a result, latency is significantly reduced because there is no need to pass data between the shards. +* **Lightweight Transaction (LWT) Optimization**: when using `TokenAwarePolicy` with prepared statements, + LWT queries automatically use replica-only routing, prioritizing local datacenter replicas to minimize + coordinator forwarding overhead and reduce contention during Paxos consensus phases. 
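+
+  A minimal sketch of the intended usage (the contact point, keyspace and `users` table below are
+  placeholders and the schema is assumed; the conditional `INSERT ... IF NOT EXISTS` is flagged as
+  LWT by the driver when it is prepared against Scylla):
+
+  ```java
+  Cluster cluster = Cluster.builder()
+      .addContactPoint("127.0.0.1")
+      .withLoadBalancingPolicy(
+          new TokenAwarePolicy(
+              DCAwareRoundRobinPolicy.builder().build(),
+              TokenAwarePolicy.ReplicaOrdering.TOPOLOGICAL))
+      // PRESERVE_REPLICA_ORDER is already the default; use REGULAR to opt out of the
+      // LWT-specific replica routing.
+      .withQueryOptions(
+          new QueryOptions()
+              .setLoadBalancingLwtRequestRoutingMethod(
+                  QueryOptions.RequestRoutingMethod.PRESERVE_REPLICA_ORDER))
+      .build();
+  Session session = cluster.connect("ks");
+  PreparedStatement insert =
+      session.prepare("INSERT INTO users (id, name) VALUES (?, ?) IF NOT EXISTS");
+  // The bound statement is routed directly to replicas of the partition for id = 42.
+  session.execute(insert.bind(42, "alice"));
+  ```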
* [Sync](manual/) and [Async](manual/async/) API * [Simple](manual/statements/simple/), [Prepared](manual/statements/prepared/), and [Batch](manual/statements/batch/) statements diff --git a/driver-core/src/main/java/com/datastax/driver/core/RequestHandler.java b/driver-core/src/main/java/com/datastax/driver/core/RequestHandler.java index 27f6b67d184..8d8a5ae8c71 100644 --- a/driver-core/src/main/java/com/datastax/driver/core/RequestHandler.java +++ b/driver-core/src/main/java/com/datastax/driver/core/RequestHandler.java @@ -97,42 +97,6 @@ class RequestHandler { private final AtomicBoolean isDone = new AtomicBoolean(); private final AtomicInteger executionIndex = new AtomicInteger(); - private Iterator getReplicas( - String loggedKeyspace, Statement statement, Iterator fallback) { - ProtocolVersion protocolVersion = manager.cluster.manager.protocolVersion(); - CodecRegistry codecRegistry = manager.cluster.manager.configuration.getCodecRegistry(); - ByteBuffer partitionKey = statement.getRoutingKey(protocolVersion, codecRegistry); - String keyspace = statement.getKeyspace(); - if (keyspace == null) { - keyspace = loggedKeyspace; - } - - if (partitionKey == null || keyspace == null) { - return fallback; - } - - Token.Factory partitioner = statement.getPartitioner(); - String tableName = null; - ColumnDefinitions defs = null; - if (statement instanceof BoundStatement) { - defs = ((BoundStatement) statement).preparedStatement().getVariables(); - } else if (statement instanceof PreparedStatement) { - defs = ((PreparedStatement) statement).getVariables(); - } - if (defs != null && defs.size() > 0) { - tableName = defs.getTable(0); - } - - final List replicas = - manager - .cluster - .getMetadata() - .getReplicasList(Metadata.quote(keyspace), tableName, partitioner, partitionKey); - - // replicas are stored in the right order starting with the primary replica - return replicas.iterator(); - } - public RequestHandler(SessionManager manager, Callback callback, Statement statement) { this.id = Long.toString(System.identityHashCode(this)); if (logger.isTraceEnabled()) logger.trace("[{}] {}", id, statement); @@ -145,15 +109,6 @@ public RequestHandler(SessionManager manager, Callback callback, Statement state // If host is explicitly set on statement, bypass load balancing policy. if (statement.getHost() != null) { this.queryPlan = new QueryPlan(Iterators.singletonIterator(statement.getHost())); - } else if (statement.isLWT()) { - this.queryPlan = - new QueryPlan( - getReplicas( - manager.poolsState.keyspace, - statement, - manager - .loadBalancingPolicy() - .newQueryPlan(manager.poolsState.keyspace, statement))); } else { this.queryPlan = new QueryPlan( diff --git a/driver-core/src/main/java/com/datastax/driver/core/policies/TokenAwarePolicy.java b/driver-core/src/main/java/com/datastax/driver/core/policies/TokenAwarePolicy.java index 3adf4ffd5d5..98e4aeddf0d 100644 --- a/driver-core/src/main/java/com/datastax/driver/core/policies/TokenAwarePolicy.java +++ b/driver-core/src/main/java/com/datastax/driver/core/policies/TokenAwarePolicy.java @@ -31,6 +31,7 @@ import com.datastax.driver.core.PreparedStatement; import com.datastax.driver.core.ProtocolVersion; import com.datastax.driver.core.Statement; +import com.datastax.driver.core.Token; import com.google.common.collect.AbstractIterator; import java.nio.ByteBuffer; import java.util.ArrayList; @@ -50,9 +51,10 @@ *
  • the iterator returned by the {@code newQueryPlan} method will first return the {@link * HostDistance#LOCAL LOCAL} replicas for the query if possible (i.e. if the query's * {@linkplain Statement#getRoutingKey(ProtocolVersion, CodecRegistry) routing key} is not - * {@code null} and if the {@linkplain Metadata#getReplicas(String, ByteBuffer) set of - * replicas} for that partition key is not empty). If no local replica can be either found or - * successfully contacted, the rest of the query plan will fallback to the child policy's one. + * {@code null} and if the {@linkplain Metadata#getReplicasList(String, String, Token.Factory, + * ByteBuffer) set of replicas} for that partition key is not empty). If no local replica can + * be either found or successfully contacted, the rest of the query plan will fall back to the + * child policy's one. * * * The exact order in which local replicas are returned is dictated by the {@linkplain @@ -63,6 +65,55 @@ * be considered having priority. For example, if you wrap {@link DCAwareRoundRobinPolicy} with this * token aware policy, replicas from remote data centers may only be returned after all the hosts of * the local data center. + * + *
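+ *
+ * <p>For example, a typical composition looks like the following sketch (the local datacenter
+ * name {@code "dc1"} and the contact point are placeholders):
+ *
+ * <pre>{@code
+ * LoadBalancingPolicy policy =
+ *     new TokenAwarePolicy(
+ *         DCAwareRoundRobinPolicy.builder().withLocalDc("dc1").build(),
+ *         TokenAwarePolicy.ReplicaOrdering.TOPOLOGICAL);
+ * Cluster cluster =
+ *     Cluster.builder().addContactPoint("127.0.0.1").withLoadBalancingPolicy(policy).build();
+ * }</pre>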

+ * <h3>Lightweight Transaction (LWT) Routing</h3>
+ *
+ * <p>For {@linkplain Statement#isLWT() lightweight transaction} queries, this policy provides
+ * specialized replica-only routing to optimize LWT performance and avoid contention. When LWT
+ * routing is enabled (the default), the query plan contains only replicas for the target
+ * partition, ordered by datacenter locality:
+ *
+ * <ul>
+ *   <li>Local replicas first: replicas for which the child policy reports {@link
+ *       HostDistance#LOCAL LOCAL} distance are returned first, in the order provided by cluster
+ *       metadata (preserving primary replica ordering from the token ring).
+ *   <li>Remote replicas second: remaining replicas (typically in remote datacenters) are
+ *       appended, but only if they are up and not ignored by the child policy.
+ *   <li>Replica-only routing when possible: under normal conditions, LWT query plans target only
+ *       replicas for the partition in order to reduce coordinator forwarding overhead and improve
+ *       performance. When replica information is unavailable, the driver falls back to the child
+ *       policy as described in the fallback behavior below, which may include non-replica hosts.
+ * </ul>
+ *
+ * <p>Rack awareness is intentionally not applied to LWT replica ordering. All local replicas are
+ * treated equally within the local datacenter to avoid rack-based contention hotspots during
+ * Paxos consensus phases.
+ *
+ * <p>Requirements for LWT replica-only routing:
+ *
+ * <ul>
+ *   <li>The statement's {@linkplain Statement#getRoutingKey(ProtocolVersion, CodecRegistry)
+ *       routing key} must be available (use {@linkplain PreparedStatement prepared statements} or
+ *       manually set the routing key).
+ *   <li>The effective keyspace must be known (set on the statement or session).
+ *   <li>Cluster metadata must contain replica information for the target partition.
+ *   <li>A child policy that correctly reports datacenter locality via {@link
+ *       LoadBalancingPolicy#distance(Host)} (e.g., {@link DCAwareRoundRobinPolicy}) must be
+ *       configured.
+ * </ul>
+ *
+ * <p>Fallback behavior: If the routing key, keyspace, or replica metadata is unavailable, the
+ * driver falls back to the child policy's normal query plan. In this case, the query plan may
+ * include non-replica hosts, and LWT may incur additional coordinator forwarding latency. This
+ * fallback is a pragmatic safety net to preserve availability when routing information is
+ * incomplete.
+ *
    LWT routing can be configured via {@link + * QueryOptions#setLoadBalancingLwtRequestRoutingMethod(QueryOptions.RequestRoutingMethod)}. The + * default preserves replica order for optimal LWT performance. + * + * @see DCAwareRoundRobinPolicy + * @see QueryOptions.RequestRoutingMethod */ public class TokenAwarePolicy implements ChainableLoadBalancingPolicy { @@ -100,11 +151,188 @@ public enum ReplicaOrdering { NEUTRAL } + /** + * An iterator that returns local replicas first (in the order provided by the child policy), then + * the remaining hosts. + */ + private class NeutralHostIterator extends AbstractIterator { + + private final Iterator childIterator; + private final List replicas; + private List nonReplicas; + private Iterator nonReplicasIterator; + + public NeutralHostIterator(Iterator childIterator, List replicas) { + this.childIterator = childIterator; + this.replicas = replicas; + } + + @Override + protected Host computeNext() { + + while (childIterator.hasNext()) { + + Host host = childIterator.next(); + + if (host.isUp() + && replicas.contains(host) + && childPolicy.distance(host) == HostDistance.LOCAL) { + // UP replicas should be prioritized, retaining order from childPolicy + return host; + } else { + // save for later + if (nonReplicas == null) nonReplicas = new ArrayList<>(); + nonReplicas.add(host); + } + } + + // This should only engage if all local replicas are DOWN + if (nonReplicas != null) { + + if (nonReplicasIterator == null) nonReplicasIterator = nonReplicas.iterator(); + + if (nonReplicasIterator.hasNext()) return nonReplicasIterator.next(); + } + + return endOfData(); + } + } + + /** + * An iterator that returns local replicas first (in either random or topological order, as + * specified at instantiation), then the remaining hosts. + */ + private class RandomOrTopologicalHostIterator extends AbstractIterator { + + private final Iterator replicasIterator; + private final String keyspace; + private final Statement statement; + private final List replicas; + private Iterator childIterator; + + public RandomOrTopologicalHostIterator( + String keyspace, + Statement statement, + Iterator replicasIterator, + List replicas) { + this.replicasIterator = replicasIterator; + this.keyspace = keyspace; + this.statement = statement; + this.replicas = replicas; + } + + @Override + protected Host computeNext() { + while (replicasIterator.hasNext()) { + Host host = replicasIterator.next(); + if (host.isUp() && childPolicy.distance(host) == HostDistance.LOCAL) return host; + } + + if (childIterator == null) childIterator = childPolicy.newQueryPlan(keyspace, statement); + + while (childIterator.hasNext()) { + Host host = childIterator.next(); + // Skip it if it was already a local replica + if (!replicas.contains(host) || childPolicy.distance(host) != HostDistance.LOCAL) + return host; + } + return endOfData(); + } + } + + /** + * An iterator that returns replicas first, with local replicas prioritized (preserving primary + * replica order), then remote replicas. Used for LWT queries to ensure replica-only routing and + * minimize coordinator forwarding overhead. DOWN and IGNORED hosts are filtered out. + * + *

+   * <p>Query plan follows a three-pass strategy:
+   *
+   * <ol>
+   *   <li>Local replicas: Returns UP replicas marked as LOCAL by the child policy, in the order
+   *       provided by cluster metadata (preserving primary replica order).
+   *   <li>Remote replicas: Returns UP replicas marked as REMOTE by the child policy.
+   *   <li>Child policy fallback: If no suitable replicas are available (for example, all are
+   *       DOWN or IGNORED and thus none are returned), falls back to the child policy's query
+   *       plan for the remaining hosts. The child policy's plan is used as-is and may include
+   *       hosts that were already considered by this iterator.
+   * </ol>
    + */ + private class PreserveReplicaOrderIterator extends AbstractIterator { + private final Iterator replicasIterator; + private final String keyspace; + private final Statement statement; + private List nonLocalReplicas; + private Iterator nonLocalReplicasIterator; + private boolean hasReturnedReplicas; + private Iterator childIterator; + + public PreserveReplicaOrderIterator( + String keyspace, Statement statement, Iterator replicasIterator) { + this.keyspace = keyspace; + this.statement = statement; + this.replicasIterator = replicasIterator; + } + + @Override + protected Host computeNext() { + // First pass: return local replicas that are UP + while (replicasIterator.hasNext()) { + Host host = replicasIterator.next(); + HostDistance distance = childPolicy.distance(host); + + if (!host.isUp()) { + // Skip DOWN hosts entirely + continue; + } + + switch (distance) { + case LOCAL: + hasReturnedReplicas = true; + return host; + case REMOTE: + // Collect remote replicas for second pass + if (nonLocalReplicas == null) nonLocalReplicas = new ArrayList<>(); + nonLocalReplicas.add(host); + break; + case IGNORED: // Skip IGNORED hosts entirely + default: // For safety, treat any unexpected distance as IGNORED + break; + } + } + + // Second pass: return remote replicas that are UP and not IGNORED + if (nonLocalReplicas != null) { + if (nonLocalReplicasIterator == null) { + nonLocalReplicasIterator = nonLocalReplicas.iterator(); + } + if (nonLocalReplicasIterator.hasNext()) { + hasReturnedReplicas = true; + return nonLocalReplicasIterator.next(); + } + } + + // Third pass: fallback to child policy if no suitable replicas were returned + // This handles cases where all replicas are empty, DOWN or IGNORED + if (!hasReturnedReplicas) { + if (childIterator == null) { + childIterator = childPolicy.newQueryPlan(keyspace, statement); + } + if (childIterator.hasNext()) { + return childIterator.next(); + } + } + + return endOfData(); + } + } + private final LoadBalancingPolicy childPolicy; private final ReplicaOrdering replicaOrdering; private volatile Metadata clusterMetadata; private volatile ProtocolVersion protocolVersion; private volatile CodecRegistry codecRegistry; + private volatile QueryOptions.RequestRoutingMethod defaultLwtRequestRoutingMethod; /** * Creates a new {@code TokenAware} policy. @@ -127,7 +355,6 @@ public TokenAwarePolicy(LoadBalancingPolicy childPolicy, ReplicaOrdering replica * @deprecated Use {@link #TokenAwarePolicy(LoadBalancingPolicy, ReplicaOrdering)} instead. This * constructor will be removed in the next major release. */ - @SuppressWarnings("DeprecatedIsStillUsed") @Deprecated public TokenAwarePolicy(LoadBalancingPolicy childPolicy, boolean shuffleReplicas) { this(childPolicy, shuffleReplicas ? 
ReplicaOrdering.RANDOM : ReplicaOrdering.TOPOLOGICAL); @@ -153,6 +380,8 @@ public void init(Cluster cluster, Collection hosts) { clusterMetadata = cluster.getMetadata(); protocolVersion = cluster.getConfiguration().getProtocolOptions().getProtocolVersion(); codecRegistry = cluster.getConfiguration().getCodecRegistry(); + defaultLwtRequestRoutingMethod = + cluster.getConfiguration().getQueryOptions().getLoadBalancingLwtRequestRoutingMethod(); childPolicy.init(cluster, hosts); } @@ -179,7 +408,6 @@ public HostDistance distance(Host host) { */ @Override public Iterator newQueryPlan(final String loggedKeyspace, final Statement statement) { - ByteBuffer partitionKey = statement.getRoutingKey(protocolVersion, codecRegistry); String keyspace = statement.getKeyspace(); if (keyspace == null) keyspace = loggedKeyspace; @@ -201,47 +429,34 @@ public Iterator newQueryPlan(final String loggedKeyspace, final Statement final List replicas = clusterMetadata.getReplicasList( Metadata.quote(keyspace), tableName, statement.getPartitioner(), partitionKey); - if (replicas.isEmpty()) return childPolicy.newQueryPlan(loggedKeyspace, statement); - - if (replicaOrdering == ReplicaOrdering.NEUTRAL) { - final Iterator childIterator = childPolicy.newQueryPlan(keyspace, statement); - - return new AbstractIterator() { - - private List nonReplicas; - private Iterator nonReplicasIterator; - - @Override - protected Host computeNext() { - - while (childIterator.hasNext()) { - - Host host = childIterator.next(); + switch (getRequestRouting(statement)) { + case PRESERVE_REPLICA_ORDER: + return newQueryPlanPreserveReplicaOrder(keyspace, statement, replicas); + case REGULAR: + default: + return newQueryPlanRegular(keyspace, statement, replicas); + } + } - if (host.isUp() - && replicas.contains(host) - && childPolicy.distance(host) == HostDistance.LOCAL) { - // UP replicas should be prioritized, retaining order from childPolicy - return host; - } else { - // save for later - if (nonReplicas == null) nonReplicas = new ArrayList(); - nonReplicas.add(host); - } - } + private QueryOptions.RequestRoutingMethod getRequestRouting(Statement statement) { + if (!statement.isLWT() || defaultLwtRequestRoutingMethod == null) { + return QueryOptions.RequestRoutingMethod.REGULAR; + } + return defaultLwtRequestRoutingMethod; + } - // This should only engage if all local replicas are DOWN - if (nonReplicas != null) { + private Iterator newQueryPlanRegular( + String keyspace, Statement statement, List replicas) { + if (replicas.isEmpty()) { + return childPolicy.newQueryPlan(keyspace, statement); + } - if (nonReplicasIterator == null) nonReplicasIterator = nonReplicas.iterator(); + if (replicaOrdering == ReplicaOrdering.NEUTRAL) { - if (nonReplicasIterator.hasNext()) return nonReplicasIterator.next(); - } + final Iterator childIterator = childPolicy.newQueryPlan(keyspace, statement); - return endOfData(); - } - }; + return new NeutralHostIterator(childIterator, replicas); } else { @@ -255,32 +470,15 @@ protected Host computeNext() { replicasIterator = replicas.iterator(); } - return new AbstractIterator() { - - private Iterator childIterator; - - @Override - protected Host computeNext() { - while (replicasIterator.hasNext()) { - Host host = replicasIterator.next(); - if (host.isUp() && childPolicy.distance(host) == HostDistance.LOCAL) return host; - } - - if (childIterator == null) - childIterator = childPolicy.newQueryPlan(loggedKeyspace, statement); - - while (childIterator.hasNext()) { - Host host = childIterator.next(); - // Skip it if it was 
already a local replica - if (!replicas.contains(host) || childPolicy.distance(host) != HostDistance.LOCAL) - return host; - } - return endOfData(); - } - }; + return new RandomOrTopologicalHostIterator(keyspace, statement, replicasIterator, replicas); } } + private Iterator newQueryPlanPreserveReplicaOrder( + String keyspace, Statement statement, List replicas) { + return new PreserveReplicaOrderIterator(keyspace, statement, replicas.iterator()); + } + @Override public void onUp(Host host) { childPolicy.onUp(host); diff --git a/driver-core/src/test/java/com/datastax/driver/core/policies/TokenAwarePolicyTest.java b/driver-core/src/test/java/com/datastax/driver/core/policies/TokenAwarePolicyTest.java index 8f780a8bef4..80a0dd66ff5 100644 --- a/driver-core/src/test/java/com/datastax/driver/core/policies/TokenAwarePolicyTest.java +++ b/driver-core/src/test/java/com/datastax/driver/core/policies/TokenAwarePolicyTest.java @@ -42,6 +42,7 @@ import com.datastax.driver.core.PreparedStatement; import com.datastax.driver.core.ProtocolOptions; import com.datastax.driver.core.ProtocolVersion; +import com.datastax.driver.core.QueryOptions; import com.datastax.driver.core.QueryTracker; import com.datastax.driver.core.RegularStatement; import com.datastax.driver.core.ResultSet; @@ -50,6 +51,7 @@ import com.datastax.driver.core.Session; import com.datastax.driver.core.SimpleStatement; import com.datastax.driver.core.SortingLoadBalancingPolicy; +import com.datastax.driver.core.Statement; import com.datastax.driver.core.TestUtils; import com.datastax.driver.core.TypeCodec; import com.google.common.collect.ImmutableMap; @@ -64,6 +66,7 @@ public class TokenAwarePolicyTest { + public static final String KEYSPACE = "keyspace"; private ByteBuffer routingKey = ByteBuffer.wrap(new byte[] {1, 2, 3, 4}); private RegularStatement statement = new SimpleStatement("irrelevant").setRoutingKey(routingKey); @@ -75,6 +78,8 @@ public class TokenAwarePolicyTest { private LoadBalancingPolicy childPolicy; private Cluster cluster; + private Metadata metadata; + private QueryOptions queryOptions; @BeforeMethod(groups = "unit") public void initMocks() { @@ -82,18 +87,22 @@ public void initMocks() { cluster = mock(Cluster.class); Configuration configuration = mock(Configuration.class); ProtocolOptions protocolOptions = mock(ProtocolOptions.class); - Metadata metadata = mock(Metadata.class); + queryOptions = mock(QueryOptions.class); + when(queryOptions.getLoadBalancingLwtRequestRoutingMethod()) + .thenReturn(QueryOptions.RequestRoutingMethod.PRESERVE_REPLICA_ORDER); + metadata = mock(Metadata.class); childPolicy = mock(LoadBalancingPolicy.class); when(cluster.getConfiguration()).thenReturn(configuration); when(configuration.getCodecRegistry()).thenReturn(codecRegistry); when(configuration.getProtocolOptions()).thenReturn(protocolOptions); + when(configuration.getQueryOptions()).thenReturn(queryOptions); when(protocolOptions.getProtocolVersion()).thenReturn(ProtocolVersion.DEFAULT); when(cluster.getMetadata()).thenReturn(metadata); - when(metadata.getReplicas(Metadata.quote("keyspace"), null, null, routingKey)) + when(metadata.getReplicas(Metadata.quote(KEYSPACE), null, null, routingKey)) .thenReturn(Sets.newLinkedHashSet(host1, host2)); - when(metadata.getReplicasList(Metadata.quote("keyspace"), null, null, routingKey)) + when(metadata.getReplicasList(Metadata.quote(KEYSPACE), null, null, routingKey)) .thenReturn(Lists.newArrayList(host1, host2)); - when(childPolicy.newQueryPlan("keyspace", statement)) + 
when(childPolicy.newQueryPlan(KEYSPACE, statement)) .thenReturn(Sets.newLinkedHashSet(host4, host3, host2, host1).iterator()); when(childPolicy.distance(any(Host.class))).thenReturn(HostDistance.LOCAL); when(host1.isUp()).thenReturn(true); @@ -117,7 +126,7 @@ public void should_respect_topological_order() { TokenAwarePolicy policy = new TokenAwarePolicy(childPolicy, TOPOLOGICAL); policy.init(cluster, null); // when - Iterator queryPlan = policy.newQueryPlan("keyspace", statement); + Iterator queryPlan = policy.newQueryPlan(KEYSPACE, statement); // then assertThat(queryPlan).containsExactly(host1, host2, host4, host3); } @@ -128,7 +137,7 @@ public void should_respect_child_policy_order() { TokenAwarePolicy policy = new TokenAwarePolicy(childPolicy, NEUTRAL); policy.init(cluster, null); // when - Iterator queryPlan = policy.newQueryPlan("keyspace", statement); + Iterator queryPlan = policy.newQueryPlan(KEYSPACE, statement); // then assertThat(queryPlan).containsExactly(host2, host1, host4, host3); } @@ -139,11 +148,291 @@ public void should_create_random_order() { TokenAwarePolicy policy = new TokenAwarePolicy(childPolicy, RANDOM); policy.init(cluster, null); // when - Iterator queryPlan = policy.newQueryPlan("keyspace", statement); + Iterator queryPlan = policy.newQueryPlan(KEYSPACE, statement); // then assertThat(queryPlan).containsOnlyOnce(host1, host2, host3, host4).endsWith(host4, host3); } + @Test(groups = "unit", dataProvider = "shuffleProvider") + public void should_prioritize_local_replicas_for_lwt(TokenAwarePolicy.ReplicaOrdering ordering) { + // given + Statement lwtStatement = mock(Statement.class); + when(lwtStatement.isLWT()).thenReturn(true); + when(lwtStatement.getRoutingKey(any(ProtocolVersion.class), any(CodecRegistry.class))) + .thenReturn(routingKey); + when(lwtStatement.getKeyspace()).thenReturn(KEYSPACE); + when(childPolicy.distance(host1)).thenReturn(HostDistance.REMOTE); + when(childPolicy.distance(host2)).thenReturn(HostDistance.LOCAL); + + TokenAwarePolicy policy = new TokenAwarePolicy(childPolicy, ordering); + policy.init(cluster, null); + + // when + Iterator queryPlan = policy.newQueryPlan(KEYSPACE, lwtStatement); + + // then: local replica first, then remaining replicas only + assertThat(queryPlan).containsExactly(host2, host1); + } + + @Test(groups = "unit", dataProvider = "shuffleProvider") + public void should_preserve_replica_order_for_lwt(TokenAwarePolicy.ReplicaOrdering ordering) { + // given + Statement lwtStatement = mock(Statement.class); + when(lwtStatement.isLWT()).thenReturn(true); + when(lwtStatement.getRoutingKey(any(ProtocolVersion.class), any(CodecRegistry.class))) + .thenReturn(routingKey); + when(lwtStatement.getKeyspace()).thenReturn(KEYSPACE); + when(metadata.getReplicasList(Metadata.quote(KEYSPACE), null, null, routingKey)) + .thenReturn(Lists.newArrayList(host2, host3, host1)); + + TokenAwarePolicy policy = new TokenAwarePolicy(childPolicy, ordering); + policy.init(cluster, null); + + // when + Iterator queryPlan = policy.newQueryPlan(KEYSPACE, lwtStatement); + + // then: replica order preserved and only replicas returned + assertThat(queryPlan).containsExactly(host2, host3, host1); + } + + @Test(groups = "unit") + public void should_fallback_to_child_policy_for_lwt_when_no_replicas() { + // given + Statement lwtStatement = mock(Statement.class); + when(lwtStatement.isLWT()).thenReturn(true); + when(lwtStatement.getRoutingKey(any(ProtocolVersion.class), any(CodecRegistry.class))) + .thenReturn(routingKey); + 
when(lwtStatement.getKeyspace()).thenReturn(KEYSPACE); + when(metadata.getReplicasList(Metadata.quote(KEYSPACE), null, null, routingKey)) + .thenReturn(Lists.newArrayList()); + when(childPolicy.newQueryPlan(KEYSPACE, lwtStatement)) + .thenReturn(Sets.newLinkedHashSet(host4, host3, host2, host1).iterator()); + + TokenAwarePolicy policy = new TokenAwarePolicy(childPolicy, TOPOLOGICAL); + policy.init(cluster, null); + + // when + Iterator queryPlan = policy.newQueryPlan(KEYSPACE, lwtStatement); + + // then: fallback to child policy plan + assertThat(queryPlan).containsExactly(host4, host3, host2, host1); + } + + @Test(groups = "unit", dataProvider = "shuffleProvider") + public void should_filter_down_replicas_for_lwt(TokenAwarePolicy.ReplicaOrdering ordering) { + // given: LWT statement with some replicas DOWN + Statement lwtStatement = mock(Statement.class); + when(lwtStatement.isLWT()).thenReturn(true); + when(lwtStatement.getRoutingKey(any(ProtocolVersion.class), any(CodecRegistry.class))) + .thenReturn(routingKey); + when(lwtStatement.getKeyspace()).thenReturn(KEYSPACE); + when(metadata.getReplicasList(Metadata.quote(KEYSPACE), null, null, routingKey)) + .thenReturn(Lists.newArrayList(host1, host2, host3)); + + // host1 is LOCAL but DOWN + when(childPolicy.distance(host1)).thenReturn(HostDistance.LOCAL); + when(host1.isUp()).thenReturn(false); + + // host2 is LOCAL and UP + when(childPolicy.distance(host2)).thenReturn(HostDistance.LOCAL); + when(host2.isUp()).thenReturn(true); + + // host3 is REMOTE but DOWN + when(childPolicy.distance(host3)).thenReturn(HostDistance.REMOTE); + when(host3.isUp()).thenReturn(false); + + TokenAwarePolicy policy = new TokenAwarePolicy(childPolicy, ordering); + policy.init(cluster, null); + + // when + Iterator queryPlan = policy.newQueryPlan(KEYSPACE, lwtStatement); + + // then: only UP replicas are returned (host1 and host3 are DOWN so excluded) + assertThat(queryPlan).containsExactly(host2); + } + + @Test(groups = "unit", dataProvider = "shuffleProvider") + public void should_filter_ignored_replicas_for_lwt(TokenAwarePolicy.ReplicaOrdering ordering) { + // given: LWT statement with some replicas IGNORED + Statement lwtStatement = mock(Statement.class); + when(lwtStatement.isLWT()).thenReturn(true); + when(lwtStatement.getRoutingKey(any(ProtocolVersion.class), any(CodecRegistry.class))) + .thenReturn(routingKey); + when(lwtStatement.getKeyspace()).thenReturn(KEYSPACE); + when(metadata.getReplicasList(Metadata.quote(KEYSPACE), null, null, routingKey)) + .thenReturn(Lists.newArrayList(host1, host2, host3)); + + // host1 is LOCAL and UP + when(childPolicy.distance(host1)).thenReturn(HostDistance.LOCAL); + when(host1.isUp()).thenReturn(true); + + // host2 is IGNORED (e.g., filtered by allowlist) + when(childPolicy.distance(host2)).thenReturn(HostDistance.IGNORED); + when(host2.isUp()).thenReturn(true); + + // host3 is REMOTE and UP + when(childPolicy.distance(host3)).thenReturn(HostDistance.REMOTE); + when(host3.isUp()).thenReturn(true); + + TokenAwarePolicy policy = new TokenAwarePolicy(childPolicy, ordering); + policy.init(cluster, null); + + // when + Iterator queryPlan = policy.newQueryPlan(KEYSPACE, lwtStatement); + + // then: IGNORED replicas are excluded (host2), local first then remote + assertThat(queryPlan).containsExactly(host1, host3); + } + + @Test(groups = "unit", dataProvider = "shuffleProvider") + public void should_filter_down_and_ignored_replicas_for_lwt( + TokenAwarePolicy.ReplicaOrdering ordering) { + // given: LWT statement with mixed 
replica states + Statement lwtStatement = mock(Statement.class); + when(lwtStatement.isLWT()).thenReturn(true); + when(lwtStatement.getRoutingKey(any(ProtocolVersion.class), any(CodecRegistry.class))) + .thenReturn(routingKey); + when(lwtStatement.getKeyspace()).thenReturn(KEYSPACE); + when(metadata.getReplicasList(Metadata.quote(KEYSPACE), null, null, routingKey)) + .thenReturn(Lists.newArrayList(host1, host2, host3, host4)); + + // host1 is LOCAL and UP + when(childPolicy.distance(host1)).thenReturn(HostDistance.LOCAL); + when(host1.isUp()).thenReturn(true); + + // host2 is LOCAL but DOWN + when(childPolicy.distance(host2)).thenReturn(HostDistance.LOCAL); + when(host2.isUp()).thenReturn(false); + + // host3 is REMOTE but IGNORED + when(childPolicy.distance(host3)).thenReturn(HostDistance.IGNORED); + when(host3.isUp()).thenReturn(true); + + // host4 is REMOTE and UP + when(childPolicy.distance(host4)).thenReturn(HostDistance.REMOTE); + when(host4.isUp()).thenReturn(true); + + TokenAwarePolicy policy = new TokenAwarePolicy(childPolicy, ordering); + policy.init(cluster, null); + + // when + Iterator queryPlan = policy.newQueryPlan(KEYSPACE, lwtStatement); + + // then: only UP and non-IGNORED replicas, local first + assertThat(queryPlan).containsExactly(host1, host4); + } + + /** + * Given an LWT statement where all replicas are either DOWN or IGNORED, ensures that the returned + * query plan falls back to the child policy. + * + * @param ordering the replica ordering to use in the TokenAwarePolicy + */ + @Test(groups = "unit", dataProvider = "shuffleProvider") + public void should_fallback_to_child_when_all_lwt_replicas_filtered( + TokenAwarePolicy.ReplicaOrdering ordering) { + // given: LWT statement where all replicas are DOWN or IGNORED + Statement lwtStatement = mock(Statement.class); + when(lwtStatement.isLWT()).thenReturn(true); + when(lwtStatement.getRoutingKey(any(ProtocolVersion.class), any(CodecRegistry.class))) + .thenReturn(routingKey); + when(lwtStatement.getKeyspace()).thenReturn(KEYSPACE); + when(metadata.getReplicasList(Metadata.quote(KEYSPACE), null, null, routingKey)) + .thenReturn(Lists.newArrayList(host1, host2)); + + // host1 is DOWN + when(childPolicy.distance(host1)).thenReturn(HostDistance.LOCAL); + when(host1.isUp()).thenReturn(false); + + // host2 is IGNORED + when(childPolicy.distance(host2)).thenReturn(HostDistance.IGNORED); + when(host2.isUp()).thenReturn(true); + + // hosts 3 & 4 are non-replicas and can be down + when(childPolicy.distance(host3)).thenReturn(HostDistance.REMOTE); + when(host3.isUp()).thenReturn(true); + when(childPolicy.distance(host4)).thenReturn(HostDistance.REMOTE); + when(host4.isUp()).thenReturn(false); + + // Mock child policy to return available hosts + when(childPolicy.newQueryPlan(KEYSPACE, lwtStatement)) + .thenReturn(Sets.newLinkedHashSet(host3, host4).iterator()); + + TokenAwarePolicy policy = new TokenAwarePolicy(childPolicy, ordering); + policy.init(cluster, null); + + // when + Iterator queryPlan = policy.newQueryPlan(KEYSPACE, lwtStatement); + + // then: fallback to child policy plan (all replicas filtered out) + assertThat(queryPlan).containsExactly(host3, host4); + } + + @Test(groups = "unit", dataProvider = "shuffleProvider") + public void should_return_all_local_replicas_when_all_replicas_are_local( + TokenAwarePolicy.ReplicaOrdering ordering) { + // given: LWT statement where all replicas are LOCAL and UP (edge case for NPE guard) + Statement lwtStatement = mock(Statement.class); + 
when(lwtStatement.isLWT()).thenReturn(true); + when(lwtStatement.getRoutingKey(any(ProtocolVersion.class), any(CodecRegistry.class))) + .thenReturn(routingKey); + when(lwtStatement.getKeyspace()).thenReturn(KEYSPACE); + when(metadata.getReplicasList(Metadata.quote(KEYSPACE), null, null, routingKey)) + .thenReturn(Lists.newArrayList(host1, host2, host3)); + + // All replicas are LOCAL and UP (no non-local replicas to collect) + when(childPolicy.distance(host1)).thenReturn(HostDistance.LOCAL); + when(host1.isUp()).thenReturn(true); + + when(childPolicy.distance(host2)).thenReturn(HostDistance.LOCAL); + when(host2.isUp()).thenReturn(true); + + when(childPolicy.distance(host3)).thenReturn(HostDistance.LOCAL); + when(host3.isUp()).thenReturn(true); + + TokenAwarePolicy policy = new TokenAwarePolicy(childPolicy, ordering); + policy.init(cluster, null); + + // when + Iterator queryPlan = policy.newQueryPlan(KEYSPACE, lwtStatement); + + // then: should return all local replicas without NPE (nonLocalReplicas remains null) + assertThat(queryPlan).containsExactly(host1, host2, host3); + } + + @Test(groups = "unit", dataProvider = "shuffleProvider") + public void should_allow_child_policy_to_retry_down_replicas_in_fallback( + TokenAwarePolicy.ReplicaOrdering ordering) { + // given: LWT statement where all replicas are DOWN + Statement lwtStatement = mock(Statement.class); + when(lwtStatement.isLWT()).thenReturn(true); + when(lwtStatement.getRoutingKey(any(ProtocolVersion.class), any(CodecRegistry.class))) + .thenReturn(routingKey); + when(lwtStatement.getKeyspace()).thenReturn(KEYSPACE); + when(metadata.getReplicasList(Metadata.quote(KEYSPACE), null, null, routingKey)) + .thenReturn(Lists.newArrayList(host1, host2)); + + // Both replicas are DOWN + when(childPolicy.distance(host1)).thenReturn(HostDistance.LOCAL); + when(host1.isUp()).thenReturn(false); + when(childPolicy.distance(host2)).thenReturn(HostDistance.REMOTE); + when(host2.isUp()).thenReturn(false); + + // Child policy includes the DOWN replicas in its plan (it may have different logic) + when(childPolicy.newQueryPlan(KEYSPACE, lwtStatement)) + .thenReturn(Lists.newArrayList(host1, host2, host3, host4).iterator()); + + TokenAwarePolicy policy = new TokenAwarePolicy(childPolicy, ordering); + policy.init(cluster, null); + + // when + Iterator queryPlan = policy.newQueryPlan(KEYSPACE, lwtStatement); + + // then: fallback to child policy, which can include even the DOWN replicas + // (no filtering, child policy decides) + assertThat(queryPlan).containsExactly(host1, host2, host3, host4); + } + /** * Ensures that {@link TokenAwarePolicy} will shuffle discovered replicas depending on the value * of shuffleReplicas used when constructing with {@link @@ -159,7 +448,7 @@ public void should_order_replicas_based_on_configuration( // given: an 8 node cluster using TokenAwarePolicy and some shuffle replica configuration with a // keyspace with replication factor of 3. ScassandraCluster sCluster = - ScassandraCluster.builder().withNodes(8).withSimpleKeyspace("keyspace", 3).build(); + ScassandraCluster.builder().withNodes(8).withSimpleKeyspace(KEYSPACE, 3).build(); LoadBalancingPolicy loadBalancingPolicy = new TokenAwarePolicy(new SortingLoadBalancingPolicy(), ordering); @@ -184,7 +473,7 @@ public void should_order_replicas_based_on_configuration( // then: The replicas resolved from the cluster metadata must match node 6 and its replicas. 
List replicas = - Lists.newArrayList(cluster.getMetadata().getReplicas("keyspace", null, routingKey)); + Lists.newArrayList(cluster.getMetadata().getReplicas(KEYSPACE, null, routingKey)); assertThat(replicas) .containsExactly( sCluster.host(cluster, 1, 6), @@ -196,7 +485,7 @@ public void should_order_replicas_based_on_configuration( // Actual query does not matter, only the keyspace and routing key will be used SimpleStatement statement = new SimpleStatement("select * from table where k=5"); statement.setRoutingKey(routingKey); - statement.setKeyspace("keyspace"); + statement.setKeyspace(KEYSPACE); List queryPlan = Lists.newArrayList(loadBalancingPolicy.newQueryPlan(null, statement)); assertThat(queryPlan).containsOnlyElementsOf(cluster.getMetadata().getAllHosts()); @@ -227,7 +516,7 @@ public void should_order_replicas_based_on_configuration( public void should_choose_proper_host_based_on_routing_key() { // given: A 3 node cluster using TokenAwarePolicy with a replication factor of 1. ScassandraCluster sCluster = - ScassandraCluster.builder().withNodes(3).withSimpleKeyspace("keyspace", 1).build(); + ScassandraCluster.builder().withNodes(3).withSimpleKeyspace(KEYSPACE, 1).build(); Cluster cluster = Cluster.builder() .addContactPoints(sCluster.address(1).getAddress()) @@ -249,7 +538,7 @@ public void should_choose_proper_host_based_on_routing_key() { SimpleStatement statement = new SimpleStatement("select * from table where k=5") .setRoutingKey(routingKey) - .setKeyspace("keyspace"); + .setKeyspace(KEYSPACE); QueryTracker queryTracker = new QueryTracker(); queryTracker.query(session, 10, statement); @@ -278,7 +567,7 @@ public void should_choose_host_in_local_dc_when_using_network_topology_strategy_ ScassandraCluster sCluster = ScassandraCluster.builder() .withNodes(3, 3) - .withNetworkTopologyKeyspace("keyspace", ImmutableMap.of(1, 1, 2, 1)) + .withNetworkTopologyKeyspace(KEYSPACE, ImmutableMap.of(1, 1, 2, 1)) .build(); @SuppressWarnings("deprecation") Cluster cluster = @@ -310,7 +599,7 @@ public void should_choose_host_in_local_dc_when_using_network_topology_strategy_ SimpleStatement statement = new SimpleStatement("select * from table where k=5") .setRoutingKey(routingKey) - .setKeyspace("keyspace"); + .setKeyspace(KEYSPACE); QueryTracker queryTracker = new QueryTracker(); queryTracker.query(session, 10, statement); @@ -335,7 +624,7 @@ public void should_choose_host_in_local_dc_when_using_network_topology_strategy_ public void should_use_other_nodes_when_replicas_having_token_are_down() { // given: A 4 node cluster using TokenAwarePolicy with a replication factor of 2. ScassandraCluster sCluster = - ScassandraCluster.builder().withNodes(4).withSimpleKeyspace("keyspace", 2).build(); + ScassandraCluster.builder().withNodes(4).withSimpleKeyspace(KEYSPACE, 2).build(); Cluster cluster = Cluster.builder() .addContactPoints(sCluster.address(2).getAddress()) @@ -361,7 +650,7 @@ public void should_use_other_nodes_when_replicas_having_token_are_down() { SimpleStatement statement = new SimpleStatement("select * from table where k=5") .setRoutingKey(routingKey) - .setKeyspace("keyspace"); + .setKeyspace(KEYSPACE); QueryTracker queryTracker = new QueryTracker(); queryTracker.query(session, 10, statement); @@ -435,7 +724,7 @@ public void should_use_other_nodes_when_replicas_having_token_are_down() { public void should_use_provided_routing_key_boundstatement() { // given: A 4 node cluster using TokenAwarePolicy with a replication factor of 1. 
ScassandraCluster sCluster = - ScassandraCluster.builder().withNodes(4).withSimpleKeyspace("keyspace", 1).build(); + ScassandraCluster.builder().withNodes(4).withSimpleKeyspace(KEYSPACE, 1).build(); Cluster cluster = Cluster.builder() .addContactPoints(sCluster.address(2).getAddress()) @@ -449,7 +738,7 @@ public void should_use_provided_routing_key_boundstatement() { try { sCluster.init(); - Session session = cluster.connect("keyspace"); + Session session = cluster.connect(KEYSPACE); PreparedStatement preparedStatement = session.prepare("insert into tbl (k0, v) values (?, ?)"); @@ -546,6 +835,7 @@ public void should_properly_generate_and_use_routing_key_for_composite_partition assertThat(rs.getExecutionInfo().getQueriedHost()).isEqualTo(host1); assertThat(rs.isExhausted()).isFalse(); Row r = rs.one(); + assertThat(rs.getExecutionInfo().getQueriedHost()).isEqualTo(host1); assertThat(rs.isExhausted()).isTrue(); assertThat(r.getInt("i")).isEqualTo(3); From 55a8336826332398b45125a6e8fc077b8ef5d3f8 Mon Sep 17 00:00:00 2001 From: Mikita Hradovich Date: Thu, 29 Jan 2026 21:30:26 +0100 Subject: [PATCH 2/4] =?UTF-8?q?feat:=20Add=20configurable=20load=20balanci?= =?UTF-8?q?ng=20request=20routing=20method=20for=20LWT=20queries.=20?= =?UTF-8?q?=E2=9A=99=EF=B8=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../datastax/driver/core/QueryOptions.java | 41 ++++++++++++++++++- .../core/policies/TokenAwarePolicy.java | 1 + 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/driver-core/src/main/java/com/datastax/driver/core/QueryOptions.java b/driver-core/src/main/java/com/datastax/driver/core/QueryOptions.java index da3a4b0aaee..ee55dfcf381 100644 --- a/driver-core/src/main/java/com/datastax/driver/core/QueryOptions.java +++ b/driver-core/src/main/java/com/datastax/driver/core/QueryOptions.java @@ -50,6 +50,9 @@ public class QueryOptions { public static final int DEFAULT_REFRESH_SCHEMA_INTERVAL_MILLIS = 1000; + public static final RequestRoutingMethod DEFAULT_LOAD_BALANCING_LWT_REQUEST_ROUTING_METHOD = + RequestRoutingMethod.PRESERVE_REPLICA_ORDER; + private volatile ConsistencyLevel consistency = DEFAULT_CONSISTENCY_LEVEL; private volatile ConsistencyLevel serialConsistency = DEFAULT_SERIAL_CONSISTENCY_LEVEL; private volatile int fetchSize = DEFAULT_FETCH_SIZE; @@ -79,6 +82,9 @@ public class QueryOptions { private volatile boolean addOriginalContactsToReconnectionPlan = false; private volatile boolean considerZeroTokenNodesValidPeers = false; + private volatile RequestRoutingMethod loadBalancingLwtRequestRoutingMethod = + DEFAULT_LOAD_BALANCING_LWT_REQUEST_ROUTING_METHOD; + /** * Creates a new {@link QueryOptions} instance using the {@link #DEFAULT_CONSISTENCY_LEVEL}, * {@link #DEFAULT_SERIAL_CONSISTENCY_LEVEL} and {@link #DEFAULT_FETCH_SIZE}. @@ -221,7 +227,7 @@ public QueryOptions setSkipCQL4MetadataResolveMethod(CQL4SkipMetadataResolveMeth /** * Skip metadata resolve method . * - *

    It defaults to {@link #skipCQL4MetadataResolveMethod.SMART}. + *

    It defaults to {@link CQL4SkipMetadataResolveMethod#SMART}. * * @return the default idempotence for queries. */ @@ -574,6 +580,28 @@ public boolean shouldConsiderZeroTokenNodesValidPeers() { return this.considerZeroTokenNodesValidPeers; } + /** + * Sets the default request routing method to use for LWT queries. Default is {@link + * RequestRoutingMethod#PRESERVE_REPLICA_ORDER}. + * + * @param loadBalancingLwtRequestRoutingMethod the new request routing method. + * @return this {@code QueryOptions} instance. + */ + public QueryOptions setLoadBalancingLwtRequestRoutingMethod( + RequestRoutingMethod loadBalancingLwtRequestRoutingMethod) { + this.loadBalancingLwtRequestRoutingMethod = loadBalancingLwtRequestRoutingMethod; + return this; + } + + /** + * The default request routing method used by LWT queries. + * + * @return the default request routing method used by LWT queries. + */ + public RequestRoutingMethod getLoadBalancingLwtRequestRoutingMethod() { + return loadBalancingLwtRequestRoutingMethod; + } + @Override public boolean equals(Object that) { if (that == null || !(that instanceof QueryOptions)) { @@ -594,7 +622,9 @@ public boolean equals(Object that) { && this.refreshNodeIntervalMillis == other.refreshNodeIntervalMillis && this.refreshSchemaIntervalMillis == other.refreshSchemaIntervalMillis && this.reprepareOnUp == other.reprepareOnUp - && this.prepareOnAllHosts == other.prepareOnAllHosts) + && this.prepareOnAllHosts == other.prepareOnAllHosts + && this.loadBalancingLwtRequestRoutingMethod + == other.loadBalancingLwtRequestRoutingMethod) && this.schemaQueriesPaged == other.schemaQueriesPaged; } @@ -614,6 +644,7 @@ public int hashCode() { refreshSchemaIntervalMillis, reprepareOnUp, prepareOnAllHosts, + loadBalancingLwtRequestRoutingMethod, schemaQueriesPaged); } @@ -626,4 +657,10 @@ public enum CQL4SkipMetadataResolveMethod { DISABLED, SMART } + + /** The request routing method for queries. 
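+   * Currently honored for LWT statements routed by {@code TokenAwarePolicy}: {@code REGULAR}
+   * uses the policy's regular token-aware query plan, while {@code PRESERVE_REPLICA_ORDER} (the
+   * default) restricts the plan to the partition's replicas, in the order provided by cluster
+   * metadata, falling back to the child policy only when no suitable replica is available.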
*/ + public enum RequestRoutingMethod { + REGULAR, + PRESERVE_REPLICA_ORDER + } } diff --git a/driver-core/src/main/java/com/datastax/driver/core/policies/TokenAwarePolicy.java b/driver-core/src/main/java/com/datastax/driver/core/policies/TokenAwarePolicy.java index 98e4aeddf0d..27022c235c8 100644 --- a/driver-core/src/main/java/com/datastax/driver/core/policies/TokenAwarePolicy.java +++ b/driver-core/src/main/java/com/datastax/driver/core/policies/TokenAwarePolicy.java @@ -30,6 +30,7 @@ import com.datastax.driver.core.Metadata; import com.datastax.driver.core.PreparedStatement; import com.datastax.driver.core.ProtocolVersion; +import com.datastax.driver.core.QueryOptions; import com.datastax.driver.core.Statement; import com.datastax.driver.core.Token; import com.google.common.collect.AbstractIterator; From 02c3773078307bcd12bdb38e1d4f82829fae7ceb Mon Sep 17 00:00:00 2001 From: Mikita Hradovich Date: Mon, 2 Feb 2026 21:35:42 +0100 Subject: [PATCH 3/4] =?UTF-8?q?feat:=20Implement=20LWT=20routing=20optimiz?= =?UTF-8?q?ation=20in=20`RackAwareRoundRobinPolicy`.=20=F0=9F=96=A5?= =?UTF-8?q?=EF=B8=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../policies/RackAwareRoundRobinPolicy.java | 33 ++++-- .../RackAwareRoundRobinPolicyTest.java | 107 ++++++++++++++++++ manual/load_balancing/README.md | 7 ++ 3 files changed, 138 insertions(+), 9 deletions(-) diff --git a/driver-core/src/main/java/com/datastax/driver/core/policies/RackAwareRoundRobinPolicy.java b/driver-core/src/main/java/com/datastax/driver/core/policies/RackAwareRoundRobinPolicy.java index 92e02cc0edd..4730948b388 100644 --- a/driver-core/src/main/java/com/datastax/driver/core/policies/RackAwareRoundRobinPolicy.java +++ b/driver-core/src/main/java/com/datastax/driver/core/policies/RackAwareRoundRobinPolicy.java @@ -56,6 +56,11 @@ * but those are always tried after the local nodes. In other words, this policy guarantees that no * host in a remote data center will be queried unless no host in the local data center can be * reached. + * + *

    For LWT (Lightweight Transaction) queries (where {@link Statement#isLWT()} returns {@code + * true}), the policy skips local rack prioritization and treats all hosts in the local datacenter + * equally, distributing queries in round-robin fashion across the entire local DC. Remote + * datacenters are still only used as fallback after all local DC hosts have been tried. */ public class RackAwareRoundRobinPolicy implements LoadBalancingPolicy { @@ -73,11 +78,11 @@ public static Builder builder() { private static final String UNSET = ""; private final ConcurrentMap> perDcLiveHosts = - new ConcurrentHashMap>(); - private final CopyOnWriteArrayList liveHostsLocalRackLocalDC = - new CopyOnWriteArrayList(); + new ConcurrentHashMap<>(); + private final CopyOnWriteArrayList liveHostsAllLocalDC = new CopyOnWriteArrayList<>(); + private final CopyOnWriteArrayList liveHostsLocalRackLocalDC = new CopyOnWriteArrayList<>(); private final CopyOnWriteArrayList liveHostsRemoteRacksLocalDC = - new CopyOnWriteArrayList(); + new CopyOnWriteArrayList<>(); @VisibleForTesting final AtomicInteger index = new AtomicInteger(); @VisibleForTesting volatile String localDc; @@ -147,6 +152,7 @@ public void init(Cluster cluster, Collection hosts) { else prev.addIfAbsent(host); if (dc.equals(localDc)) { + liveHostsAllLocalDC.add(host); if (rack.equals(localRack)) { liveHostsLocalRackLocalDC.add(host); } else { @@ -240,10 +246,17 @@ public HostDistance distance(Host host) { @Override public Iterator newQueryPlan(String loggedKeyspace, final Statement statement) { - CopyOnWriteArrayList localLiveHosts = perDcLiveHosts.get(localDc); - // Clone for thread safety - final List copyLiveHostsLocalRackLocalDC = cloneList(liveHostsLocalRackLocalDC); - final List copyLiveHostsRemoteRacksLocalDC = cloneList(liveHostsRemoteRacksLocalDC); + // For LWT queries, skip rack prioritization and use all local DC hosts equally + final boolean isLWT = statement != null && statement.isLWT(); + + // For LWT queries, include all local DC hosts in the first part of the plan, not just those in + // the local rack + final List copyLiveHostsLocalRackLocalDC = + isLWT ? cloneList(liveHostsAllLocalDC) : cloneList(liveHostsLocalRackLocalDC); + // For LWT queries, skip the second part of the plan that includes hosts in remote racks of the + // local DC + final List copyLiveHostsRemoteRacksLocalDC = + isLWT ? Collections.emptyList() : cloneList(liveHostsRemoteRacksLocalDC); final int startIdx = index.getAndIncrement(); return new AbstractIterator() { @@ -288,7 +301,7 @@ protected Host computeNext() { } ConsistencyLevel cl = - statement.getConsistencyLevel() == null + statement == null || statement.getConsistencyLevel() == null ? 
configuration.getQueryOptions().getConsistencyLevel() : statement.getConsistencyLevel(); @@ -348,6 +361,7 @@ public void onUp(Host host) { dcHosts.addIfAbsent(host); if (dc.equals(localDc)) { + liveHostsAllLocalDC.addIfAbsent(host); if (rack.equals(localRack)) { liveHostsLocalRackLocalDC.add(host); } else { @@ -365,6 +379,7 @@ public void onDown(Host host) { if (dcHosts != null) dcHosts.remove(host); if (dc.equals(localDc)) { + liveHostsAllLocalDC.remove(host); if (rack.equals(localRack)) { liveHostsLocalRackLocalDC.remove(host); } else { diff --git a/driver-core/src/test/java/com/datastax/driver/core/policies/RackAwareRoundRobinPolicyTest.java b/driver-core/src/test/java/com/datastax/driver/core/policies/RackAwareRoundRobinPolicyTest.java index c7a33219257..b23b72571ac 100644 --- a/driver-core/src/test/java/com/datastax/driver/core/policies/RackAwareRoundRobinPolicyTest.java +++ b/driver-core/src/test/java/com/datastax/driver/core/policies/RackAwareRoundRobinPolicyTest.java @@ -86,11 +86,14 @@ public void setUpUnitTests() { cluster = mock(Cluster.class); Configuration configuration = mock(Configuration.class); ProtocolOptions protocolOptions = mock(ProtocolOptions.class); + QueryOptions queryOptions = mock(QueryOptions.class); Metadata metadata = mock(Metadata.class); childPolicy = mock(LoadBalancingPolicy.class); when(cluster.getConfiguration()).thenReturn(configuration); when(configuration.getCodecRegistry()).thenReturn(codecRegistry); when(configuration.getProtocolOptions()).thenReturn(protocolOptions); + when(configuration.getQueryOptions()).thenReturn(queryOptions); + when(queryOptions.getConsistencyLevel()).thenReturn(ConsistencyLevel.ONE); when(protocolOptions.getProtocolVersion()).thenReturn(ProtocolVersion.DEFAULT); when(cluster.getMetadata()).thenReturn(metadata); when(host1.isUp()).thenReturn(true); @@ -1107,6 +1110,110 @@ public void should_follow_configuration_on_query_planning( .containsExactly(queryPlanForNonLocalConsistencyLevel2.toArray(new Host[0])); } + /** + * Ensures that {@link RackAwareRoundRobinPolicy} skips rack prioritization for LWT queries, + * treating all local DC hosts equally while still prioritizing local DC over remote DC. 
+ * + * @test_category load_balancing:rack_aware,lwt + */ + @Test(groups = "unit") + public void should_skip_rack_prioritization_for_lwt_queries() { + // given: a policy with 4 local DC hosts (2 in local rack, 2 in remote rack) and 2 remote DC + // hosts + // Initialize hosts in a mixed order: remoteRack, localRack, remoteRack, localRack + // This ensures that when LWT skips rack prioritization, we get a different order + // than the rack-aware order + RackAwareRoundRobinPolicy policy = + new RackAwareRoundRobinPolicy("localDC", "localRack", 1, false, false, false); + policy.init(cluster, ImmutableList.of(host3, host1, host4, host2, host5, host6)); + + // Create a mock LWT statement + Statement lwtStatement = mock(Statement.class); + when(lwtStatement.isLWT()).thenReturn(true); + when(lwtStatement.getConsistencyLevel()).thenReturn(ConsistencyLevel.ONE); + + // when: generating query plans for LWT queries + policy.index.set(0); + List queryPlan1 = Lists.newArrayList(policy.newQueryPlan("keyspace", lwtStatement)); + List queryPlan2 = Lists.newArrayList(policy.newQueryPlan("keyspace", lwtStatement)); + + // then: all 4 local DC hosts should appear before any remote DC host (no rack prioritization) + Assertions.assertThat(queryPlan1.subList(0, 4)).containsOnly(host1, host2, host3, host4); + Assertions.assertThat(queryPlan2.subList(0, 4)).containsOnly(host1, host2, host3, host4); + + // then: remote DC hosts should appear after all local DC hosts + Assertions.assertThat(queryPlan1.subList(4, 5)).containsOnly(host5); + Assertions.assertThat(queryPlan2.subList(4, 5)).containsOnly(host5); + + // then: for LWT queries, order should follow insertion order (host3, host1, host4, host2) + // not rack-aware order (host1, host2, host3, host4) + Assertions.assertThat(queryPlan1).startsWith(host3); + Assertions.assertThat(queryPlan2).startsWith(host1); + } + + /** + * Ensures that {@link RackAwareRoundRobinPolicy} preserves rack-aware routing for non-LWT + * queries. 
+ * + * @test_category load_balancing:rack_aware + */ + @Test(groups = "unit") + public void should_preserve_rack_aware_routing_for_non_lwt_queries() { + // given: a policy with 4 local DC hosts (2 in local rack, 2 in remote rack) and 2 remote DC + // hosts + // Initialize hosts in a mixed order to ensure rack-aware routing reorganizes them + RackAwareRoundRobinPolicy policy = + new RackAwareRoundRobinPolicy("localDC", "localRack", 1, false, false, false); + policy.init(cluster, ImmutableList.of(host3, host1, host4, host2, host5, host6)); + + // Create a normal (non-LWT) statement + Statement normalStatement = mock(Statement.class); + when(normalStatement.isLWT()).thenReturn(false); + when(normalStatement.getConsistencyLevel()).thenReturn(ConsistencyLevel.ONE); + + // when: generating query plans for non-LWT queries + policy.index.set(0); + List queryPlan1 = Lists.newArrayList(policy.newQueryPlan("keyspace", normalStatement)); + List queryPlan2 = Lists.newArrayList(policy.newQueryPlan("keyspace", normalStatement)); + + // then: local rack hosts (host1, host2) should appear first regardless of init order + Assertions.assertThat(queryPlan1.subList(0, 2)).containsOnly(host1, host2); + Assertions.assertThat(queryPlan2.subList(0, 2)).containsOnly(host1, host2); + + // then: remote rack local DC hosts (host3, host4) should appear next + Assertions.assertThat(queryPlan1.subList(2, 4)).containsOnly(host3, host4); + Assertions.assertThat(queryPlan2.subList(2, 4)).containsOnly(host3, host4); + + // then: remote DC hosts should appear last + Assertions.assertThat(queryPlan1.subList(4, 5)).containsOnly(host5); + Assertions.assertThat(queryPlan2.subList(4, 5)).containsOnly(host5); + + // then: query plans should follow round-robin pattern within rack boundaries + Assertions.assertThat(queryPlan1).startsWith(host1); + Assertions.assertThat(queryPlan2).startsWith(host2); + } + + /** + * Ensures that {@link RackAwareRoundRobinPolicy} handles null statement correctly. + * + * @test_category load_balancing:rack_aware + */ + @Test(groups = "unit") + public void should_handle_null_statement() { + // given: a policy with hosts in local and remote DC + RackAwareRoundRobinPolicy policy = + new RackAwareRoundRobinPolicy("localDC", "localRack", 1, false, false, false); + policy.init(cluster, ImmutableList.of(host1, host2, host3, host4, host5, host6)); + + // when: generating query plan with null statement + policy.index.set(0); + List queryPlan = Lists.newArrayList(policy.newQueryPlan("keyspace", null)); + + // then: should use rack-aware routing (default behavior for non-LWT) + // Local rack hosts should appear first + Assertions.assertThat(queryPlan.subList(0, 2)).containsOnly(host1, host2); + } + @DataProvider(name = "distanceTestCases") public Object[][] distanceTestCases() { return new Object[][] { diff --git a/manual/load_balancing/README.md b/manual/load_balancing/README.md index 07d5b3dcbb2..32f4ae05499 100644 --- a/manual/load_balancing/README.md +++ b/manual/load_balancing/README.md @@ -138,6 +138,13 @@ local datacenter and rack. In general, providing the datacenter and rack name ex Hosts belonging to the local datacenter are at distance `LOCAL`, and appear first in query plans (in a round-robin fashion) with hosts in the local rack having precedence over nodes in remote racks in the local datacenter. +**LWT (Lightweight Transaction) Behavior:** For LWT queries (`Statement.isLWT()` returns true), the policy does not +prioritize the local rack. 
Instead, it round-robins evenly across all hosts in the local datacenter first, then falls +back to remote datacenters (if enabled). This design avoids creating rack-level hotspots during Paxos consensus phases. +LWT queries involve multiple rounds of coordination between replicas, and concentrating these operations on a single +rack can create contention that degrades performance. By distributing LWT load across the entire local datacenter, +the driver achieves better throughput while maintaining low latency through datacenter locality. + For example, if there are any UP hosts in the local rack the policy will query those nodes in round-robin fashion: * query 1: host1 *(local DC, local rack)*, host2 *(local DC, local rack)*, host3 *(local DC, local rack)* * query 2: host2 *(local DC, local rack)*, host3 *(local DC, local rack)*, host1 *(local DC, local rack)* From a5f4e7ae8ea2d55b9fa41efadf44cdd93e4c3a00 Mon Sep 17 00:00:00 2001 From: Mikita Hradovich Date: Tue, 3 Feb 2026 00:54:03 +0100 Subject: [PATCH 4/4] =?UTF-8?q?feat:=20Preserve=20deterministic=20routing?= =?UTF-8?q?=20for=20LWT=20queries=20in=20`LatencyAwarePolicy`.=20=E2=8F=B1?= =?UTF-8?q?=EF=B8=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 14 ++++-- .../core/policies/LatencyAwarePolicy.java | 14 +++++- .../core/policies/LatencyAwarePolicyTest.java | 49 +++++++++++++++++++ 3 files changed, 73 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 928a48fccf2..5ea20557f82 100644 --- a/README.md +++ b/README.md @@ -19,9 +19,17 @@ The Scylla Java Driver is a fork from [DataStax Java Driver](https://github.com/ * Like all Scylla Drivers, the Scylla Java Driver is **Shard Aware** and contains extensions for a `tokenAwareHostPolicy`. Using this policy, the driver can select a connection to a particular shard based on the shard's token. As a result, latency is significantly reduced because there is no need to pass data between the shards. -* **Lightweight Transaction (LWT) Optimization**: when using `TokenAwarePolicy` with prepared statements, - LWT queries automatically use replica-only routing, prioritizing local datacenter replicas to minimize - coordinator forwarding overhead and reduce contention during Paxos consensus phases. +* **Lightweight Transaction (LWT) Optimization**: + - When using `TokenAwarePolicy` with prepared statements, LWT queries automatically use replica-only routing, + prioritizing local datacenter replicas to minimize coordinator forwarding overhead and reduce contention during + Paxos consensus phases. + - When using `RackAwareRoundRobinPolicy`, LWT queries skip local rack prioritization and distribute evenly across + all hosts in the local datacenter. This avoids creating rack-level hotspots during Paxos consensus, which can + lead to increased contention and reduced throughput. The local datacenter is still prioritized over remote + datacenters to maintain low latency. + - When using `LatencyAwarePolicy`, LWT queries bypass latency-based reordering to preserve deterministic replica + selection. This ensures that LWT routing assumptions (such as consistent coordinator selection for optimal Paxos + performance) are maintained throughout the policy chain. 
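
As an illustration of the combined behavior described in the bullets above (not part of the diff itself), the sketch below shows how an application might compose these policies so the LWT optimizations apply end to end. The contact point, keyspace, table, and column names are placeholders, and the conditional insert is assumed to be detected as an LWT once prepared:

```java
import com.datastax.driver.core.BoundStatement;
import com.datastax.driver.core.Cluster;
import com.datastax.driver.core.PreparedStatement;
import com.datastax.driver.core.Session;
import com.datastax.driver.core.policies.DCAwareRoundRobinPolicy;
import com.datastax.driver.core.policies.LatencyAwarePolicy;
import com.datastax.driver.core.policies.TokenAwarePolicy;

public class LwtRoutingExample {
  public static void main(String[] args) {
    // Chain the policies: token awareness on the outside, latency awareness in the middle,
    // DC-aware round-robin as the final child. LWT plans pass through LatencyAwarePolicy
    // unchanged, so the replica order chosen by TokenAwarePolicy is preserved.
    Cluster cluster =
        Cluster.builder()
            .addContactPoint("127.0.0.1") // placeholder contact point
            .withLoadBalancingPolicy(
                new TokenAwarePolicy(
                    LatencyAwarePolicy.builder(
                            DCAwareRoundRobinPolicy.builder().withLocalDc("dc1").build())
                        .build()))
            .build();

    Session session = cluster.connect("examples"); // placeholder keyspace

    // Conditional (IF NOT EXISTS) writes are LWTs; once prepared, the driver can flag the
    // bound statement as LWT and route it to replicas of the partition key, local DC first.
    PreparedStatement insert =
        session.prepare("INSERT INTO users (id, name) VALUES (?, ?) IF NOT EXISTS");
    BoundStatement bound = insert.bind(42, "alice"); // placeholder schema and values
    session.execute(bound);

    cluster.close();
  }
}
```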
* [Sync](manual/) and [Async](manual/async/) API * [Simple](manual/statements/simple/), [Prepared](manual/statements/prepared/), and [Batch](manual/statements/batch/) statements diff --git a/driver-core/src/main/java/com/datastax/driver/core/policies/LatencyAwarePolicy.java b/driver-core/src/main/java/com/datastax/driver/core/policies/LatencyAwarePolicy.java index 34152601258..c31f0875ca3 100644 --- a/driver-core/src/main/java/com/datastax/driver/core/policies/LatencyAwarePolicy.java +++ b/driver-core/src/main/java/com/datastax/driver/core/policies/LatencyAwarePolicy.java @@ -62,6 +62,11 @@ * they will only be tried if all other nodes failed). Note that this policy only penalizes slow * nodes, it does not globally sort the query plan by latency. * + *

    LWT statements: if {@link Statement#isLWT()} returns {@code true}, this + * policy does not apply latency-based reordering and returns the child policy's query plan as-is. + * This is to preserve LWT-specific routing assumptions (for example deterministic replica selection + * when using {@link TokenAwarePolicy}). + * *

    The latency score for a given node is a based on a form of exponential moving * average. In other words, the latency score of a node is the average of its previously @@ -145,7 +150,7 @@ public void run() { if (logger.isDebugEnabled()) { /* * For users to be able to know if the policy potentially needs tuning, we need to provide - * some feedback on on how things evolve. For that, we use the min computation to also check + * some feedback on how things evolve. For that, we use the min computation to also check * which host will be excluded if a query is submitted now and if any host is, we log it (but * we try to avoid flooding too). This is probably interesting information anyway since it * gets an idea of which host perform badly. @@ -253,6 +258,13 @@ public HostDistance distance(Host host) { */ @Override public Iterator newQueryPlan(String loggedKeyspace, Statement statement) { + // For LWT queries, preserve the child policy's ordering. + // LWT routing can rely on deterministic replica ordering (e.g. by TokenAwarePolicy), and + // latency-based reordering can undermine those assumptions. + if (statement != null && statement.isLWT()) { + return childPolicy.newQueryPlan(loggedKeyspace, statement); + } + final Iterator childIter = childPolicy.newQueryPlan(loggedKeyspace, statement); return new AbstractIterator() { diff --git a/driver-core/src/test/java/com/datastax/driver/core/policies/LatencyAwarePolicyTest.java b/driver-core/src/test/java/com/datastax/driver/core/policies/LatencyAwarePolicyTest.java index ebbeb686e48..e64d390864c 100644 --- a/driver-core/src/test/java/com/datastax/driver/core/policies/LatencyAwarePolicyTest.java +++ b/driver-core/src/test/java/com/datastax/driver/core/policies/LatencyAwarePolicyTest.java @@ -28,10 +28,13 @@ import com.datastax.driver.core.LatencyTracker; import com.datastax.driver.core.ScassandraTestBase; import com.datastax.driver.core.Session; +import com.datastax.driver.core.SimpleStatement; import com.datastax.driver.core.Statement; import com.datastax.driver.core.exceptions.NoHostAvailableException; import com.datastax.driver.core.exceptions.ReadTimeoutException; import com.datastax.driver.core.exceptions.UnavailableException; +import com.google.common.collect.Lists; +import java.util.Iterator; import java.util.concurrent.CountDownLatch; import org.testng.annotations.Test; @@ -178,4 +181,50 @@ public void should_consider_latency_when_read_timeout() throws Exception { cluster.close(); } } + + @Test(groups = "short") + public void should_not_reorder_query_plan_for_lwt_queries() throws Exception { + // given + String query = "SELECT foo FROM bar"; + primingClient.prime(queryBuilder().withQuery(query).build()); + + LatencyAwarePolicy latencyAwarePolicy = + LatencyAwarePolicy.builder(new RoundRobinPolicy()).withMininumMeasurements(1).build(); + + Cluster.Builder builder = super.createClusterBuilder(); + builder.withLoadBalancingPolicy(latencyAwarePolicy); + + Cluster cluster = builder.build(); + try { + cluster.init(); + + // Create an LWT statement so latency-aware policy must preserve child ordering + Statement lwtStatement = + new SimpleStatement(query) { + @Override + public boolean isLWT() { + return true; + } + }; + + // Make a request to populate latency metrics + LatencyTrackerBarrier barrier = new LatencyTrackerBarrier(1); + cluster.register(barrier); + Session session = cluster.connect(); + session.execute(query); + barrier.await(); + latencyAwarePolicy.new Updater().run(); + + // when + Iterator plan1 = 
latencyAwarePolicy.newQueryPlan("ks", lwtStatement);
+      Iterator<Host> plan2 = latencyAwarePolicy.newQueryPlan("ks", lwtStatement);
+
+      // then
+      Host host = retrieveSingleHost(cluster);
+      assertThat(Lists.newArrayList(plan1)).containsExactly(host);
+      assertThat(Lists.newArrayList(plan2)).containsExactly(host);
+    } finally {
+      cluster.close();
+    }
+  }
 }
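
The unit test above pins the LWT pass-through against a single host. Against a running cluster, the same effect can be observed with a small probe like the sketch below (illustrative only, not part of the patch; the contact point, keyspace name, and query are placeholders, and the LWT flag is forced with the same anonymous-subclass trick the test uses):

```java
import com.datastax.driver.core.Cluster;
import com.datastax.driver.core.Host;
import com.datastax.driver.core.SimpleStatement;
import com.datastax.driver.core.Statement;
import com.datastax.driver.core.policies.LatencyAwarePolicy;
import com.datastax.driver.core.policies.LoadBalancingPolicy;
import com.datastax.driver.core.policies.RoundRobinPolicy;
import java.util.Iterator;

public class QueryPlanProbe {
  public static void main(String[] args) {
    // Same chain as the test: LatencyAwarePolicy wrapping a plain round-robin child.
    Cluster cluster =
        Cluster.builder()
            .addContactPoint("127.0.0.1") // placeholder contact point
            .withLoadBalancingPolicy(LatencyAwarePolicy.builder(new RoundRobinPolicy()).build())
            .build();
    try {
      cluster.init();
      LoadBalancingPolicy policy =
          cluster.getConfiguration().getPolicies().getLoadBalancingPolicy();

      // Force the LWT flag via an anonymous subclass, as the test does.
      Statement lwt =
          new SimpleStatement("SELECT foo FROM bar") {
            @Override
            public boolean isLWT() {
              return true;
            }
          };
      Statement regular = new SimpleStatement("SELECT foo FROM bar");

      // The LWT plan should match the child policy's ordering; the non-LWT plan may be
      // reordered or trimmed by latency scores once measurements accumulate.
      printPlan("LWT plan", policy.newQueryPlan("ks", lwt));
      printPlan("non-LWT plan", policy.newQueryPlan("ks", regular));
    } finally {
      cluster.close();
    }
  }

  private static void printPlan(String label, Iterator<Host> plan) {
    StringBuilder sb = new StringBuilder(label).append(": ");
    while (plan.hasNext()) {
      sb.append(plan.next()).append(' ');
    }
    System.out.println(sb);
  }
}
```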