33from dataclasses import dataclass
44from typing import Literal , Optional , Union
55
6+ from gigl .distributed .graph_store .sharding import ServerSlice
67from gigl .src .common .types .graph_data import EdgeType , NodeType
78
89
@@ -17,6 +18,9 @@ class FetchNodesRequest:
1718 Must be provided together with ``rank``.
1819 split: The split of the dataset to get node ids from.
1920 node_type: The type of nodes to get node ids for.
21+ server_slice: An optional :class:`~gigl.distributed.graph_store.sharding.ServerSlice`
22+ describing the fraction of this server's data to return.
23+ Cannot be combined with ``rank``/``world_size``.
2024
2125 Examples:
2226 Fetch all nodes without sharding:
@@ -36,18 +40,25 @@ class FetchNodesRequest:
3640 world_size : Optional [int ] = None
3741 split : Optional [Union [Literal ["train" , "val" , "test" ], str ]] = None
3842 node_type : Optional [NodeType ] = None
43+ server_slice : Optional [ServerSlice ] = None
3944
4045 def validate (self ) -> None :
4146 """Validate that the request has consistent rank/world_size.
4247
4348 Raises:
44- ValueError: If only one of ``rank`` or ``world_size`` is provided.
49+ ValueError:
50+ If only one of ``rank`` or ``world_size`` is provided.
51+ If ``server_slice`` is provided together with ``rank`` or ``world_size``.
4552 """
4653 if (self .rank is None ) ^ (self .world_size is None ):
4754 raise ValueError (
4855 "rank and world_size must be provided together. "
4956 f"Received rank={ self .rank } , world_size={ self .world_size } "
5057 )
58+ if self .server_slice is not None and (
59+ self .rank is not None or self .world_size is not None
60+ ):
61+ raise ValueError ("server_slice cannot be combined with rank/world_size." )
5162
5263
5364@dataclass (frozen = True )
@@ -62,6 +73,9 @@ class FetchABLPInputRequest:
6273 Must be provided together with ``world_size``.
6374 world_size: The total number of processes in the distributed setup.
6475 Must be provided together with ``rank``.
76+ server_slice: An optional :class:`~gigl.distributed.graph_store.sharding.ServerSlice`
77+ describing the fraction of this server's data to return.
78+ Cannot be combined with ``rank``/``world_size``.
6579
6680 Examples:
6781 Fetch training ABLP input without sharding:
@@ -78,15 +92,22 @@ class FetchABLPInputRequest:
7892 supervision_edge_type : EdgeType
7993 rank : Optional [int ] = None
8094 world_size : Optional [int ] = None
95+ server_slice : Optional [ServerSlice ] = None
8196
8297 def validate (self ) -> None :
8398 """Validate that the request has consistent rank/world_size.
8499
85100 Raises:
86- ValueError: If only one of ``rank`` or ``world_size`` is provided.
101+ ValueError:
102+ If only one of ``rank`` or ``world_size`` is provided.
103+ If ``server_slice`` is provided together with ``rank`` or ``world_size``.
87104 """
88105 if (self .rank is None ) ^ (self .world_size is None ):
89106 raise ValueError (
90107 "rank and world_size must be provided together. "
91108 f"Received rank={ self .rank } , world_size={ self .world_size } "
92109 )
110+ if self .server_slice is not None and (
111+ self .rank is not None or self .world_size is not None
112+ ):
113+ raise ValueError ("server_slice cannot be combined with rank/world_size." )
0 commit comments