Skip to content

Commit ab359ad

Browse files
kmontemayor2-sc, kmonte, claude, kmontemayor
authored
Add CONTIGUOUS shard strategy (#545)
Co-authored-by: kmontemayor <kyle.e.montemayor@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: kmontemayor <kmontemayor@snapchat.com>
1 parent 20ed9e9 commit ab359ad

10 files changed

Lines changed: 1427 additions & 126 deletions

File tree

gigl/distributed/dist_ablp_neighborloader.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -666,7 +666,12 @@ def _setup_for_graph_store(
666666
# Extract supervision edge types and derive label edge types from the
667667
# ABLPInputNodes.labels dict (keyed by supervision edge type).
668668
self._supervision_edge_types = list(first_input.labels.keys())
669-
has_negatives = any(neg is not None for _, neg in first_input.labels.values())
669+
has_negatives = False
670+
for ablp_input in input_nodes.values():
671+
for maybe_negative_labels in ablp_input.labels.values():
672+
if maybe_negative_labels is not None:
673+
has_negatives = True
674+
break
670675

671676
self._positive_label_edge_types = [
672677
message_passing_to_positive_label(et) for et in self._supervision_edge_types

gigl/distributed/graph_store/dist_server.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -39,6 +39,7 @@
3939
FetchABLPInputRequest,
4040
FetchNodesRequest,
4141
)
42+
from gigl.distributed.graph_store.sharding import ServerSlice
4243
from gigl.distributed.sampler import ABLPNodeSamplerInput
4344
from gigl.distributed.sampler_options import PPRSamplerOptions, SamplerOptions
4445
from gigl.distributed.utils.neighborloader import shard_nodes_by_process
@@ -283,7 +284,7 @@ def get_node_ids(
283284
284285
Args:
285286
request: The node-fetch request, including split, node type,
286-
and round-robin rank/world_size.
287+
and either round-robin rank/world_size or a contiguous slice.
287288
288289
Returns:
289290
The node ids.
@@ -306,6 +307,7 @@ def get_node_ids(
306307
node_type=request.node_type,
307308
rank=request.rank,
308309
world_size=request.world_size,
310+
server_slice=request.server_slice,
309311
)
310312

311313
def _get_node_ids(
@@ -314,6 +316,7 @@ def _get_node_ids(
314316
node_type: Optional[NodeType],
315317
rank: Optional[int] = None,
316318
world_size: Optional[int] = None,
319+
server_slice: Optional[ServerSlice] = None,
317320
) -> torch.Tensor:
318321
"""Core implementation for fetching node IDs by split, type, and sharding.
319322
@@ -366,6 +369,8 @@ def _get_node_ids(
366369
f"node_type was not provided, so node ids must be a torch.Tensor (e.g. a homogeneous dataset), got {type(nodes)}."
367370
)
368371

372+
if server_slice is not None:
373+
return server_slice.slice_tensor(nodes)
369374
if rank is not None and world_size is not None:
370375
return shard_nodes_by_process(nodes, rank, world_size)
371376
return nodes
@@ -420,6 +425,7 @@ def get_ablp_input(
420425
node_type=request.node_type,
421426
rank=request.rank,
422427
world_size=request.world_size,
428+
server_slice=request.server_slice,
423429
)
424430
positive_label_edge_type, negative_label_edge_type = select_label_edge_types(
425431
request.supervision_edge_type, self.dataset.get_edge_types()

gigl/distributed/graph_store/messages.py

Lines changed: 23 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
from dataclasses import dataclass
44
from typing import Literal, Optional, Union
55

6+
from gigl.distributed.graph_store.sharding import ServerSlice
67
from gigl.src.common.types.graph_data import EdgeType, NodeType
78

89

@@ -17,6 +18,9 @@ class FetchNodesRequest:
1718
Must be provided together with ``rank``.
1819
split: The split of the dataset to get node ids from.
1920
node_type: The type of nodes to get node ids for.
21+
server_slice: An optional :class:`~gigl.distributed.graph_store.sharding.ServerSlice`
22+
describing the fraction of this server's data to return.
23+
Cannot be combined with ``rank``/``world_size``.
2024
2125
Examples:
2226
Fetch all nodes without sharding:
@@ -36,18 +40,25 @@ class FetchNodesRequest:
3640
world_size: Optional[int] = None
3741
split: Optional[Union[Literal["train", "val", "test"], str]] = None
3842
node_type: Optional[NodeType] = None
43+
server_slice: Optional[ServerSlice] = None
3944

4045
def validate(self) -> None:
4146
"""Validate that the request has consistent rank/world_size.
4247
4348
Raises:
44-
ValueError: If only one of ``rank`` or ``world_size`` is provided.
49+
ValueError:
50+
If only one of ``rank`` or ``world_size`` is provided.
51+
If ``server_slice`` is provided together with ``rank`` or ``world_size``.
4552
"""
4653
if (self.rank is None) ^ (self.world_size is None):
4754
raise ValueError(
4855
"rank and world_size must be provided together. "
4956
f"Received rank={self.rank}, world_size={self.world_size}"
5057
)
58+
if self.server_slice is not None and (
59+
self.rank is not None or self.world_size is not None
60+
):
61+
raise ValueError("server_slice cannot be combined with rank/world_size.")
5162

5263

5364
@dataclass(frozen=True)
@@ -62,6 +73,9 @@ class FetchABLPInputRequest:
6273
Must be provided together with ``world_size``.
6374
world_size: The total number of processes in the distributed setup.
6475
Must be provided together with ``rank``.
76+
server_slice: An optional :class:`~gigl.distributed.graph_store.sharding.ServerSlice`
77+
describing the fraction of this server's data to return.
78+
Cannot be combined with ``rank``/``world_size``.
6579
6680
Examples:
6781
Fetch training ABLP input without sharding:
@@ -78,15 +92,22 @@ class FetchABLPInputRequest:
7892
supervision_edge_type: EdgeType
7993
rank: Optional[int] = None
8094
world_size: Optional[int] = None
95+
server_slice: Optional[ServerSlice] = None
8196

8297
def validate(self) -> None:
8398
"""Validate that the request has consistent rank/world_size.
8499
85100
Raises:
86-
ValueError: If only one of ``rank`` or ``world_size`` is provided.
101+
ValueError:
102+
If only one of ``rank`` or ``world_size`` is provided.
103+
If ``server_slice`` is provided together with ``rank`` or ``world_size``.
87104
"""
88105
if (self.rank is None) ^ (self.world_size is None):
89106
raise ValueError(
90107
"rank and world_size must be provided together. "
91108
f"Received rank={self.rank}, world_size={self.world_size}"
92109
)
110+
if self.server_slice is not None and (
111+
self.rank is not None or self.world_size is not None
112+
):
113+
raise ValueError("server_slice cannot be combined with rank/world_size.")

0 commit comments

Comments
 (0)