From ed818c2161a649ad37a31689154514e840efe795 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Tue, 19 May 2026 01:00:18 +0000 Subject: [PATCH 01/32] potential fix --- .../graph_store/shared_dist_sampling_producer.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/gigl/distributed/graph_store/shared_dist_sampling_producer.py b/gigl/distributed/graph_store/shared_dist_sampling_producer.py index 0f7461196..f7852c6f6 100644 --- a/gigl/distributed/graph_store/shared_dist_sampling_producer.py +++ b/gigl/distributed/graph_store/shared_dist_sampling_producer.py @@ -103,6 +103,7 @@ SamplerRuntime, create_dist_sampler, ) +from gigl.utils.share_memory import share_memory logger = Logger() @@ -871,7 +872,13 @@ def __init__( self._completed_workers: defaultdict[tuple[int, int], set[int]] = defaultdict( set ) + # Move degree tensors to shared memory before workers are spawned so + # each worker maps the same allocation instead of pickling a private copy. + # In colocated mode this is handled by DistDataset.to_ipc_handle(); here + # the tensors arrive via RPC from the storage server and are not yet in + # shared memory, causing num_workers copies without this call. self._degree_tensors = degree_tensors + share_memory(self._degree_tensors) def init_backend(self) -> None: """Initialize worker processes once for this backend. From abb8e569dcc537566b817dc4d521c6a20eadc571 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Tue, 19 May 2026 01:22:04 +0000 Subject: [PATCH 02/32] Update --- gigl/distributed/dist_ppr_sampler.py | 156 +++++++++++++----- .../shared_dist_sampling_producer.py | 44 ++++- 2 files changed, 149 insertions(+), 51 deletions(-) diff --git a/gigl/distributed/dist_ppr_sampler.py b/gigl/distributed/dist_ppr_sampler.py index 402e381c1..9aaefbfa1 100644 --- a/gigl/distributed/dist_ppr_sampler.py +++ b/gigl/distributed/dist_ppr_sampler.py @@ -37,6 +37,92 @@ ) +def build_ppr_node_type_to_edge_types( + is_homogeneous: bool, + edge_types: list[EdgeType], + edge_dir: str, +) -> dict[NodeType, list[EdgeType]]: + """Build the node-type → edge-types mapping used by the PPR forward-push kernel. + + For homogeneous graphs returns the singleton sentinel mapping. For + heterogeneous graphs, groups non-label edge types by their anchor node type + (destination for ``edge_dir="in"``, source for ``edge_dir="out"``). + + Args: + is_homogeneous: True if the graph has a single node/edge type. + edge_types: All edge types present in the graph (ignored when homogeneous). + edge_dir: Sampling direction — ``"in"`` or ``"out"``. + + Returns: + Dict mapping each anchor NodeType to the list of EdgeTypes traversable + from it during a PPR walk. + """ + if is_homogeneous: + return {_PPR_HOMOGENEOUS_NODE_TYPE: [_PPR_HOMOGENEOUS_EDGE_TYPE]} + + node_type_to_edge_types: dict[NodeType, list[EdgeType]] = defaultdict(list) + for etype in edge_types: + if is_label_edge_type(etype): + continue + anchor_type = etype[-1] if edge_dir == "in" else etype[0] + node_type_to_edge_types[anchor_type].append(etype) + return dict(node_type_to_edge_types) + + +def build_ppr_total_degree_tensors( + degree_tensors: Union[torch.Tensor, dict[EdgeType, torch.Tensor]], + dtype: torch.dtype, + node_type_to_edge_types: dict[NodeType, list[EdgeType]], +) -> dict[NodeType, torch.Tensor]: + """Pre-compute total-degree tensors for the PPR forward-push kernel. + + For homogeneous graphs converts the single degree tensor to ``dtype``. + For heterogeneous graphs sums per-edge-type degrees into a per-node-type + total, padding shorter tensors with zeros where node counts differ. + + This function is intentionally standalone so it can be called once in the + parent process (and the result shared across workers) rather than redundantly + inside each worker's ``DistPPRNeighborSampler.__init__``. + + Args: + degree_tensors: Per-edge-type degree tensors (homogeneous: single + ``torch.Tensor``; heterogeneous: ``dict[EdgeType, torch.Tensor]``). + dtype: Target dtype for the output tensors. + node_type_to_edge_types: Mapping from anchor NodeType to the list of + EdgeTypes traversable from it, as returned by + :func:`build_ppr_node_type_to_edge_types`. + + Returns: + Dict mapping NodeType to a 1-D total-degree tensor of shape + ``[num_nodes_of_that_type]`` with dtype ``dtype``. + + Raises: + ValueError: If a required edge type is missing from ``degree_tensors``. + """ + result: dict[NodeType, torch.Tensor] = {} + + if isinstance(degree_tensors, torch.Tensor): + result[_PPR_HOMOGENEOUS_NODE_TYPE] = degree_tensors.to(dtype) + else: + dtype_max = torch.iinfo(dtype).max + for node_type, edge_types in node_type_to_edge_types.items(): + max_len = 0 + for et in edge_types: + if et not in degree_tensors: + raise ValueError( + f"Edge type {et} not found in degree tensors. " + f"Available: {list(degree_tensors.keys())}" + ) + max_len = max(max_len, len(degree_tensors[et])) + summed = torch.zeros(max_len, dtype=torch.int64) + for et in edge_types: + et_degrees = degree_tensors[et] + summed[: len(et_degrees)] += et_degrees.to(torch.int64) + result[node_type] = summed.clamp(max=dtype_max).to(dtype) + + return result + + class DistPPRNeighborSampler(BaseDistNeighborSampler): """Personalized PageRank (PPR) based distributed neighbor sampler. @@ -134,14 +220,26 @@ def __init__( # edge types traversable from that node type. This is a graph-level # property used on every PPR iteration, so computing it once at init # avoids per-node summation and cache lookups in the hot loop. - # TODO (mkolodner-sc): This trades memory for throughput — we - # materialize a tensor per node type to avoid recomputing total degree - # on every neighbor during sampling. Computing it here (rather than in - # the dataset) also keeps the door open for edge-specific degree - # strategies. If memory becomes a bottleneck, revisit this. - self._node_type_to_total_degree: dict[NodeType, torch.Tensor] = ( - self._build_total_degree_tensors(degree_tensors, total_degree_dtype) - ) + # + # In graph-store mode, SharedDistSamplingProducer pre-computes the + # total-degree dict once in the parent process, moves it to shared + # memory, and passes it here as degree_tensors (keys are NodeType + # strings). In colocated mode degree_tensors arrives as raw + # per-edge-type tensors (keys are EdgeType tuples, or a bare Tensor + # for homogeneous graphs) and we compute the total here. + if ( + isinstance(degree_tensors, dict) + and degree_tensors + and not isinstance(next(iter(degree_tensors)), tuple) + ): + # Already the pre-computed total (NodeType string keys). + self._node_type_to_total_degree: dict[NodeType, torch.Tensor] = ( + degree_tensors + ) + else: + self._node_type_to_total_degree = self._build_total_degree_tensors( + degree_tensors, total_degree_dtype + ) # Build integer ID mappings for the C++ forward-push kernel. String # NodeType / EdgeType keys are only used at the Python boundary @@ -198,9 +296,7 @@ def _build_total_degree_tensors( ) -> dict[NodeType, torch.Tensor]: """Build total-degree tensors by summing per-edge-type degrees for each node type. - For homogeneous graphs, the total degree is just the single degree tensor. - For heterogeneous graphs, it sums degree tensors across all edge types - traversable from each node type, padding shorter tensors with zeros. + Delegates to the module-level :func:`build_ppr_total_degree_tensors`. Args: degree_tensors: Per-edge-type degree tensors from the dataset. @@ -209,39 +305,11 @@ def _build_total_degree_tensors( Returns: Dict mapping node type to a 1-D tensor of total degrees. """ - result: dict[NodeType, torch.Tensor] = {} - - if self._is_homogeneous: - assert isinstance(degree_tensors, torch.Tensor) - # Single edge type: degree values fit directly in the target dtype. - result[_PPR_HOMOGENEOUS_NODE_TYPE] = degree_tensors.to(dtype) - else: - assert isinstance(degree_tensors, dict) - dtype_max = torch.iinfo(dtype).max - for node_type, edge_types in self._node_type_to_edge_types.items(): - max_len = 0 - for et in edge_types: - if et not in degree_tensors: - raise ValueError( - f"Edge type {et} not found in degree tensors. " - f"Available: {list(degree_tensors.keys())}" - ) - max_len = max(max_len, len(degree_tensors[et])) - - # Each degree tensor is indexed by node ID (derived from CSR - # indptr), so index i in every edge type's tensor refers to - # the same node. Element-wise summation gives the total degree - # per node across all edge types. Shorter tensors are padded - # implicitly (only the first len(et_degrees) entries are added). - # Sum in int64: aggregate degrees are bounded by partition size - # and fit comfortably within int64 range in practice. - summed = torch.zeros(max_len, dtype=torch.int64) - for et in edge_types: - et_degrees = degree_tensors[et] - summed[: len(et_degrees)] += et_degrees.to(torch.int64) - result[node_type] = summed.clamp(max=dtype_max).to(dtype) - - return result + return build_ppr_total_degree_tensors( + degree_tensors=degree_tensors, + dtype=dtype, + node_type_to_edge_types=self._node_type_to_edge_types, + ) def _get_destination_type(self, edge_type: EdgeType) -> NodeType: """Get the node type at the destination end of an edge type.""" diff --git a/gigl/distributed/graph_store/shared_dist_sampling_producer.py b/gigl/distributed/graph_store/shared_dist_sampling_producer.py index f7852c6f6..b7838c02c 100644 --- a/gigl/distributed/graph_store/shared_dist_sampling_producer.py +++ b/gigl/distributed/graph_store/shared_dist_sampling_producer.py @@ -97,7 +97,11 @@ from torch._C import _set_worker_signal_handlers from gigl.common.logger import Logger -from gigl.distributed.sampler_options import SamplerOptions +from gigl.distributed.dist_ppr_sampler import ( + build_ppr_node_type_to_edge_types, + build_ppr_total_degree_tensors, +) +from gigl.distributed.sampler_options import PPRSamplerOptions, SamplerOptions from gigl.distributed.utils.dist_sampler import ( SamplerInput, SamplerRuntime, @@ -872,12 +876,38 @@ def __init__( self._completed_workers: defaultdict[tuple[int, int], set[int]] = defaultdict( set ) - # Move degree tensors to shared memory before workers are spawned so - # each worker maps the same allocation instead of pickling a private copy. - # In colocated mode this is handled by DistDataset.to_ipc_handle(); here - # the tensors arrive via RPC from the storage server and are not yet in - # shared memory, causing num_workers copies without this call. - self._degree_tensors = degree_tensors + # For PPR sampling, pre-compute the total-degree dict (summed across edge + # types, converted to the target dtype) once here in the parent process. + # Workers receive the result directly as degree_tensors and skip the + # per-worker summation in DistPPRNeighborSampler._build_total_degree_tensors. + # + # Then move to shared memory so all spawned workers map the same + # allocation instead of each pickling a private copy. In colocated mode + # DistDataset.to_ipc_handle() handles shared memory; here the tensors + # arrive via RPC and are plain heap allocations without this call. + if ( + isinstance(sampler_options, PPRSamplerOptions) + and degree_tensors is not None + ): + assert data.graph is not None, ( + "DistDataset.graph must be set for PPR sampling" + ) + is_homogeneous = not isinstance(data.graph, dict) + edge_types = list(data.graph.keys()) if isinstance(data.graph, dict) else [] + node_type_to_edge_types = build_ppr_node_type_to_edge_types( + is_homogeneous=is_homogeneous, + edge_types=edge_types, + edge_dir=data.edge_dir, + ) + self._degree_tensors: Optional[ + Union[torch.Tensor, dict[EdgeType, torch.Tensor]] + ] = build_ppr_total_degree_tensors( + degree_tensors=degree_tensors, + dtype=sampler_options.total_degree_dtype, + node_type_to_edge_types=node_type_to_edge_types, + ) + else: + self._degree_tensors = degree_tensors share_memory(self._degree_tensors) def init_backend(self) -> None: From a0e84fab04f6811353c8f5737a3560743134c883 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Tue, 19 May 2026 01:38:25 +0000 Subject: [PATCH 03/32] Update --- gigl/distributed/base_dist_loader.py | 36 ++++++--- gigl/distributed/dist_ppr_sampler.py | 81 +++++-------------- gigl/distributed/dist_sampling_producer.py | 8 +- .../shared_dist_sampling_producer.py | 17 ++-- gigl/distributed/sampler_options.py | 5 -- gigl/distributed/utils/dist_sampler.py | 5 +- 6 files changed, 58 insertions(+), 94 deletions(-) diff --git a/gigl/distributed/base_dist_loader.py b/gigl/distributed/base_dist_loader.py index 203c8520d..4e39273c5 100644 --- a/gigl/distributed/base_dist_loader.py +++ b/gigl/distributed/base_dist_loader.py @@ -39,6 +39,10 @@ from gigl.distributed.constants import DEFAULT_MASTER_INFERENCE_PORT from gigl.distributed.dist_context import DistributedContext from gigl.distributed.dist_dataset import DistDataset +from gigl.distributed.dist_ppr_sampler import ( + build_ppr_node_type_to_edge_types, + build_ppr_total_degree_tensors, +) from gigl.distributed.dist_sampling_producer import DistSamplingProducer from gigl.distributed.graph_store.compute import async_request_server from gigl.distributed.graph_store.dist_server import DistServer @@ -425,17 +429,27 @@ def create_mp_producer( """ channel = BaseDistLoader.create_colocated_channel(worker_options) if isinstance(sampler_options, PPRSamplerOptions): - degree_tensors = dataset.degree_tensor - if isinstance(degree_tensors, dict): - logger.info( - f"Pre-computed degree tensors for PPR sampling across " - f"{len(degree_tensors)} edge types." - ) - else: - logger.info( - f"Pre-computed degree tensor for PPR sampling with " - f"{degree_tensors.size(0)} nodes." - ) + assert dataset.graph is not None, ( + "DistDataset.graph must be set for PPR sampling" + ) + raw_degree_tensors = dataset.degree_tensor + is_homogeneous = not isinstance(dataset.graph, dict) + edge_types = ( + list(dataset.graph.keys()) if isinstance(dataset.graph, dict) else [] + ) + node_type_to_edge_types = build_ppr_node_type_to_edge_types( + is_homogeneous=is_homogeneous, + edge_types=edge_types, + edge_dir=dataset.edge_dir, + ) + degree_tensors = build_ppr_total_degree_tensors( + degree_tensors=raw_degree_tensors, + node_type_to_edge_types=node_type_to_edge_types, + ) + logger.info( + f"Pre-computed total degree tensors for PPR sampling across " + f"{len(degree_tensors)} node types." + ) else: degree_tensors = None return DistSamplingProducer( diff --git a/gigl/distributed/dist_ppr_sampler.py b/gigl/distributed/dist_ppr_sampler.py index 9aaefbfa1..c6120cffa 100644 --- a/gigl/distributed/dist_ppr_sampler.py +++ b/gigl/distributed/dist_ppr_sampler.py @@ -71,14 +71,14 @@ def build_ppr_node_type_to_edge_types( def build_ppr_total_degree_tensors( degree_tensors: Union[torch.Tensor, dict[EdgeType, torch.Tensor]], - dtype: torch.dtype, node_type_to_edge_types: dict[NodeType, list[EdgeType]], ) -> dict[NodeType, torch.Tensor]: """Pre-compute total-degree tensors for the PPR forward-push kernel. - For homogeneous graphs converts the single degree tensor to ``dtype``. + For homogeneous graphs converts the single degree tensor to int16. For heterogeneous graphs sums per-edge-type degrees into a per-node-type - total, padding shorter tensors with zeros where node counts differ. + total (capped at int16 max), padding shorter tensors with zeros where node + counts differ. This function is intentionally standalone so it can be called once in the parent process (and the result shared across workers) rather than redundantly @@ -87,24 +87,24 @@ def build_ppr_total_degree_tensors( Args: degree_tensors: Per-edge-type degree tensors (homogeneous: single ``torch.Tensor``; heterogeneous: ``dict[EdgeType, torch.Tensor]``). - dtype: Target dtype for the output tensors. node_type_to_edge_types: Mapping from anchor NodeType to the list of EdgeTypes traversable from it, as returned by :func:`build_ppr_node_type_to_edge_types`. Returns: Dict mapping NodeType to a 1-D total-degree tensor of shape - ``[num_nodes_of_that_type]`` with dtype ``dtype``. + ``[num_nodes_of_that_type]`` with dtype ``torch.int16``, capped at + ``torch.iinfo(torch.int16).max``. Raises: ValueError: If a required edge type is missing from ``degree_tensors``. """ + _INT16_MAX = torch.iinfo(torch.int16).max result: dict[NodeType, torch.Tensor] = {} if isinstance(degree_tensors, torch.Tensor): - result[_PPR_HOMOGENEOUS_NODE_TYPE] = degree_tensors.to(dtype) + result[_PPR_HOMOGENEOUS_NODE_TYPE] = degree_tensors.to(torch.int16) else: - dtype_max = torch.iinfo(dtype).max for node_type, edge_types in node_type_to_edge_types.items(): max_len = 0 for et in edge_types: @@ -118,7 +118,7 @@ def build_ppr_total_degree_tensors( for et in edge_types: et_degrees = degree_tensors[et] summed[: len(et_degrees)] += et_degrees.to(torch.int64) - result[node_type] = summed.clamp(max=dtype_max).to(dtype) + result[node_type] = summed.clamp(max=_INT16_MAX).to(torch.int16) return result @@ -160,10 +160,10 @@ class DistPPRNeighborSampler(BaseDistNeighborSampler): but require more computation. Typical values: 1e-4 to 1e-6. max_ppr_nodes: Maximum number of nodes to return per seed based on PPR scores. num_neighbors_per_hop: Maximum number of neighbors to fetch per hop. - total_degree_dtype: Dtype for precomputed total-degree tensors. Defaults - to ``torch.int32``. Use a larger dtype if nodes have exceptionally high - aggregate degrees. - degree_tensors: Pre-computed degree tensors from the dataset. + degree_tensors: Pre-computed total-degree tensors (int16, capped at + int16 max), keyed by NodeType. Must be pre-computed by the caller + (e.g. via :func:`build_ppr_total_degree_tensors`) so that workers + share a single allocation rather than recomputing per-worker. """ def __init__( @@ -173,8 +173,7 @@ def __init__( eps: float = 1e-4, max_ppr_nodes: int = 50, num_neighbors_per_hop: int = 100_000, - total_degree_dtype: torch.dtype = torch.int32, - degree_tensors: Union[torch.Tensor, dict[EdgeType, torch.Tensor]], + degree_tensors: dict[NodeType, torch.Tensor], max_fetch_iterations: Optional[int] = None, **kwargs, ): @@ -216,30 +215,12 @@ def __init__( ] self._is_homogeneous = True - # Precompute total degree per node type: the sum of degrees across all - # edge types traversable from that node type. This is a graph-level - # property used on every PPR iteration, so computing it once at init - # avoids per-node summation and cache lookups in the hot loop. - # - # In graph-store mode, SharedDistSamplingProducer pre-computes the - # total-degree dict once in the parent process, moves it to shared - # memory, and passes it here as degree_tensors (keys are NodeType - # strings). In colocated mode degree_tensors arrives as raw - # per-edge-type tensors (keys are EdgeType tuples, or a bare Tensor - # for homogeneous graphs) and we compute the total here. - if ( - isinstance(degree_tensors, dict) - and degree_tensors - and not isinstance(next(iter(degree_tensors)), tuple) - ): - # Already the pre-computed total (NodeType string keys). - self._node_type_to_total_degree: dict[NodeType, torch.Tensor] = ( - degree_tensors - ) - else: - self._node_type_to_total_degree = self._build_total_degree_tensors( - degree_tensors, total_degree_dtype - ) + # Total-degree tensors keyed by NodeType, pre-computed by the caller. + # Callers (create_mp_producer for colocated, SharedDistSamplingBackend + # for graph-store) run build_ppr_total_degree_tensors once in the parent + # process and place the result in shared memory so all worker processes + # map the same allocation. + self._node_type_to_total_degree: dict[NodeType, torch.Tensor] = degree_tensors # Build integer ID mappings for the C++ forward-push kernel. String # NodeType / EdgeType keys are only used at the Python boundary @@ -285,32 +266,10 @@ def __init__( # Degree tensors indexed by ntype_id. Destination-only types get an empty # tensor; the C++ kernel returns 0 for those, matching _get_total_degree. self._degree_tensors_for_cpp: list[torch.Tensor] = [ - self._node_type_to_total_degree.get(nt, torch.zeros(0, dtype=torch.int32)) + self._node_type_to_total_degree.get(nt, torch.zeros(0, dtype=torch.int16)) for nt in all_node_types ] - def _build_total_degree_tensors( - self, - degree_tensors: Union[torch.Tensor, dict[EdgeType, torch.Tensor]], - dtype: torch.dtype, - ) -> dict[NodeType, torch.Tensor]: - """Build total-degree tensors by summing per-edge-type degrees for each node type. - - Delegates to the module-level :func:`build_ppr_total_degree_tensors`. - - Args: - degree_tensors: Per-edge-type degree tensors from the dataset. - dtype: Dtype for the output tensors. - - Returns: - Dict mapping node type to a 1-D tensor of total degrees. - """ - return build_ppr_total_degree_tensors( - degree_tensors=degree_tensors, - dtype=dtype, - node_type_to_edge_types=self._node_type_to_edge_types, - ) - def _get_destination_type(self, edge_type: EdgeType) -> NodeType: """Get the node type at the destination end of an edge type.""" return edge_type[0] if self.edge_dir == "in" else edge_type[-1] diff --git a/gigl/distributed/dist_sampling_producer.py b/gigl/distributed/dist_sampling_producer.py index 3a51715e2..15d29a48c 100644 --- a/gigl/distributed/dist_sampling_producer.py +++ b/gigl/distributed/dist_sampling_producer.py @@ -30,7 +30,7 @@ SamplingConfig, SamplingType, ) -from graphlearn_torch.typing import EdgeType +from graphlearn_torch.typing import NodeType from graphlearn_torch.utils import seed_everything from torch._C import _set_worker_signal_handlers from torch.utils.data.dataloader import DataLoader @@ -55,7 +55,7 @@ def _sampling_worker_loop( sampling_completed_worker_count, # mp.Value mp_barrier: Barrier, sampler_options: SamplerOptions, - degree_tensors: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]], + degree_tensors: Optional[dict[NodeType, torch.Tensor]], ): dist_sampler = None try: @@ -180,9 +180,7 @@ def __init__( worker_options: MpDistSamplingWorkerOptions, channel: ChannelBase, sampler_options: SamplerOptions, - degree_tensors: Optional[ - Union[torch.Tensor, dict[EdgeType, torch.Tensor]] - ] = None, + degree_tensors: Optional[dict[NodeType, torch.Tensor]] = None, ): super().__init__(data, sampler_input, sampling_config, worker_options, channel) self._sampler_options = sampler_options diff --git a/gigl/distributed/graph_store/shared_dist_sampling_producer.py b/gigl/distributed/graph_store/shared_dist_sampling_producer.py index b7838c02c..6712ac850 100644 --- a/gigl/distributed/graph_store/shared_dist_sampling_producer.py +++ b/gigl/distributed/graph_store/shared_dist_sampling_producer.py @@ -93,7 +93,7 @@ SamplingConfig, SamplingType, ) -from graphlearn_torch.typing import EdgeType +from graphlearn_torch.typing import EdgeType, NodeType from torch._C import _set_worker_signal_handlers from gigl.common.logger import Logger @@ -343,7 +343,7 @@ def _shared_sampling_worker_loop( event_queue: mp.Queue, mp_barrier: Barrier, sampler_options: SamplerOptions, - degree_tensors: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]], + degree_tensors: Optional[dict[NodeType, torch.Tensor]], ) -> None: """Run one shared graph-store worker that schedules many input channels. @@ -899,15 +899,14 @@ def __init__( edge_types=edge_types, edge_dir=data.edge_dir, ) - self._degree_tensors: Optional[ - Union[torch.Tensor, dict[EdgeType, torch.Tensor]] - ] = build_ppr_total_degree_tensors( - degree_tensors=degree_tensors, - dtype=sampler_options.total_degree_dtype, - node_type_to_edge_types=node_type_to_edge_types, + self._degree_tensors: Optional[dict[NodeType, torch.Tensor]] = ( + build_ppr_total_degree_tensors( + degree_tensors=degree_tensors, + node_type_to_edge_types=node_type_to_edge_types, + ) ) else: - self._degree_tensors = degree_tensors + self._degree_tensors = None share_memory(self._degree_tensors) def init_backend(self) -> None: diff --git a/gigl/distributed/sampler_options.py b/gigl/distributed/sampler_options.py index fccd7a3ba..08cd27352 100644 --- a/gigl/distributed/sampler_options.py +++ b/gigl/distributed/sampler_options.py @@ -10,7 +10,6 @@ from dataclasses import dataclass from typing import Optional, Union -import torch from graphlearn_torch.typing import EdgeType from gigl.common.logger import Logger @@ -58,9 +57,6 @@ class PPRSamplerOptions: hub nodes receive diminishing residual per neighbor, so capping the fetch has little effect on PPR accuracy while keeping per-hop RPC cost bounded. Set large to approximate fetching all neighbors. - total_degree_dtype: Dtype for precomputed total-degree tensors. Defaults - to ``torch.int32``, which supports total degrees up to ~2 billion. - Use a larger dtype if nodes have exceptionally high aggregate degrees. max_fetch_iterations: Maximum number of iterations that issue RPC neighbor fetches. After this many fetch iterations, subsequent iterations push residuals using only already-cached neighbor lists (no new RPCs). @@ -73,7 +69,6 @@ class PPRSamplerOptions: eps: float = 1e-4 max_ppr_nodes: int = 50 num_neighbors_per_hop: int = 1_000 - total_degree_dtype: torch.dtype = torch.int32 max_fetch_iterations: Optional[int] = None diff --git a/gigl/distributed/utils/dist_sampler.py b/gigl/distributed/utils/dist_sampler.py index 0333f4138..db5dba1af 100644 --- a/gigl/distributed/utils/dist_sampler.py +++ b/gigl/distributed/utils/dist_sampler.py @@ -10,7 +10,7 @@ RemoteDistSamplingWorkerOptions, ) from graphlearn_torch.sampler import EdgeSamplerInput, NodeSamplerInput, SamplingConfig -from graphlearn_torch.typing import EdgeType +from graphlearn_torch.typing import NodeType from gigl.distributed.dist_neighbor_sampler import DistNeighborSampler from gigl.distributed.dist_ppr_sampler import DistPPRNeighborSampler @@ -35,7 +35,7 @@ def create_dist_sampler( worker_options: Union[MpDistSamplingWorkerOptions, RemoteDistSamplingWorkerOptions], channel: ChannelBase, sampler_options: SamplerOptions, - degree_tensors: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]], + degree_tensors: Optional[dict[NodeType, torch.Tensor]], current_device: torch.device, ) -> SamplerRuntime: """Create a GiGL sampler runtime for one channel on one worker. @@ -84,7 +84,6 @@ def create_dist_sampler( max_ppr_nodes=sampler_options.max_ppr_nodes, max_fetch_iterations=sampler_options.max_fetch_iterations, num_neighbors_per_hop=sampler_options.num_neighbors_per_hop, - total_degree_dtype=sampler_options.total_degree_dtype, degree_tensors=degree_tensors, ) else: From 088fe1bfc5a93d98b25f51ffb3380feb2bd8ee48 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Tue, 19 May 2026 03:10:49 +0000 Subject: [PATCH 04/32] Improvements --- gigl/distributed/base_dist_loader.py | 26 +--- gigl/distributed/dist_dataset.py | 34 ++--- gigl/distributed/dist_ppr_sampler.py | 112 ++------------ .../shared_dist_sampling_producer.py | 44 +----- gigl/distributed/utils/degree.py | 139 +++++++++--------- 5 files changed, 103 insertions(+), 252 deletions(-) diff --git a/gigl/distributed/base_dist_loader.py b/gigl/distributed/base_dist_loader.py index 4e39273c5..496b32381 100644 --- a/gigl/distributed/base_dist_loader.py +++ b/gigl/distributed/base_dist_loader.py @@ -39,10 +39,6 @@ from gigl.distributed.constants import DEFAULT_MASTER_INFERENCE_PORT from gigl.distributed.dist_context import DistributedContext from gigl.distributed.dist_dataset import DistDataset -from gigl.distributed.dist_ppr_sampler import ( - build_ppr_node_type_to_edge_types, - build_ppr_total_degree_tensors, -) from gigl.distributed.dist_sampling_producer import DistSamplingProducer from gigl.distributed.graph_store.compute import async_request_server from gigl.distributed.graph_store.dist_server import DistServer @@ -429,27 +425,7 @@ def create_mp_producer( """ channel = BaseDistLoader.create_colocated_channel(worker_options) if isinstance(sampler_options, PPRSamplerOptions): - assert dataset.graph is not None, ( - "DistDataset.graph must be set for PPR sampling" - ) - raw_degree_tensors = dataset.degree_tensor - is_homogeneous = not isinstance(dataset.graph, dict) - edge_types = ( - list(dataset.graph.keys()) if isinstance(dataset.graph, dict) else [] - ) - node_type_to_edge_types = build_ppr_node_type_to_edge_types( - is_homogeneous=is_homogeneous, - edge_types=edge_types, - edge_dir=dataset.edge_dir, - ) - degree_tensors = build_ppr_total_degree_tensors( - degree_tensors=raw_degree_tensors, - node_type_to_edge_types=node_type_to_edge_types, - ) - logger.info( - f"Pre-computed total degree tensors for PPR sampling across " - f"{len(degree_tensors)} node types." - ) + degree_tensors = dataset.degree_tensor else: degree_tensors = None return DistSamplingProducer( diff --git a/gigl/distributed/dist_dataset.py b/gigl/distributed/dist_dataset.py index cd38c5653..c0cf6f207 100644 --- a/gigl/distributed/dist_dataset.py +++ b/gigl/distributed/dist_dataset.py @@ -80,9 +80,7 @@ def __init__( edge_feature_info: Optional[ Union[FeatureInfo, dict[EdgeType, FeatureInfo]] ] = None, - degree_tensor: Optional[ - Union[torch.Tensor, dict[EdgeType, torch.Tensor]] - ] = None, + degree_tensor: Optional[dict[NodeType, torch.Tensor]] = None, max_labels_per_anchor_node: Optional[int] = None, ) -> None: """ @@ -108,7 +106,7 @@ def __init__( Note this will be None in the homogeneous case if the data has no node features, or will only contain node types with node features in the heterogeneous case. edge_feature_info: Optional[Union[FeatureInfo, dict[EdgeType, FeatureInfo]]]: Dimension of edge features and its data type, will be a dict if heterogeneous. Note this will be None in the homogeneous case if the data has no edge features, or will only contain edge types with edge features in the heterogeneous case. - degree_tensor: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]]: Pre-computed degree tensor. Lazily computed on first access via the degree_tensor property. + degree_tensor: Optional[dict[NodeType, torch.Tensor]]: Pre-computed degree tensor keyed by node type. Lazily computed on first access via the degree_tensor property. max_labels_per_anchor_node (Optional[int]): Optional cap for how many labels to materialize per anchor node for ABLP label fetching. """ @@ -146,9 +144,7 @@ def __init__( self._node_feature_info = node_feature_info self._edge_feature_info = edge_feature_info - self._degree_tensor: Optional[ - Union[torch.Tensor, dict[EdgeType, torch.Tensor]] - ] = degree_tensor + self._degree_tensor: Optional[dict[NodeType, torch.Tensor]] = degree_tensor self._max_labels_per_anchor_node = max_labels_per_anchor_node # TODO (mkolodner-sc): Modify so that we don't need to rely on GLT's base variable naming (i.e. partition_idx, num_partitions) in favor of more clear @@ -307,13 +303,15 @@ def edge_feature_info( @property def degree_tensor( self, - ) -> Union[torch.Tensor, dict[EdgeType, torch.Tensor]]: + ) -> dict[NodeType, torch.Tensor]: """ - Lazily compute and return the degree tensor for the graph. + Lazily compute and return the total degree tensor per node type. On first access, computes node degrees from the graph partition and uses - all-reduce to aggregate across all machines. Requires torch.distributed - to be initialized. + all-reduce to aggregate across all machines. Degrees are summed across + all incident edge types per anchor node type before the all-reduce, so + the per-edge-type tensor is never stored. Requires torch.distributed to + be initialized. Over-counting correction (for processes sharing the same data on the same machine) is handled automatically by detecting the distributed topology. @@ -321,9 +319,9 @@ def degree_tensor( The result is cached for subsequent accesses. Returns: - Union[torch.Tensor, dict[EdgeType, torch.Tensor]]: The aggregated degree tensor. - - For homogeneous graphs: A tensor of shape [num_nodes]. - - For heterogeneous graphs: A dict mapping EdgeType to degree tensors. + dict[NodeType, torch.Tensor]: Total degree tensors keyed by node type. + For homogeneous graphs the single entry uses + ``DEFAULT_HOMOGENEOUS_NODE_TYPE`` as its key. Raises: RuntimeError: If torch.distributed is not initialized. @@ -333,7 +331,9 @@ def degree_tensor( if self.graph is None: raise ValueError("Dataset graph is None. Cannot compute degrees.") - self._degree_tensor = compute_and_broadcast_degree_tensor(self.graph) + self._degree_tensor = compute_and_broadcast_degree_tensor( + self.graph, self._edge_dir + ) return self._degree_tensor @property @@ -902,7 +902,7 @@ def share_ipc( Optional[Union[int, dict[NodeType, int]]]: Number of test nodes on the current machine. Will be a dict if heterogeneous. Optional[Union[FeatureInfo, dict[NodeType, FeatureInfo]]]: Node feature dim and its data type, will be a dict if heterogeneous Optional[Union[FeatureInfo, dict[EdgeType, FeatureInfo]]]: Edge feature dim and its data type, will be a dict if heterogeneous - Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]]: Degree tensors, will be a dict if heterogeneous + Optional[dict[NodeType, torch.Tensor]]: Degree tensors keyed by node type Optional[int]: Optional per-anchor label cap for ABLP label fetching """ # TODO (mkolodner-sc): Investigate moving share_memory calls to the build() function @@ -1188,7 +1188,7 @@ def _rebuild_distributed_dataset( Optional[ Union[FeatureInfo, dict[EdgeType, FeatureInfo]] ], # Edge feature dim and its data type - Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]], # Degree tensors + Optional[dict[NodeType, torch.Tensor]], # Degree tensors Optional[int], # Optional per-anchor label cap for ABLP label fetching ], ): diff --git a/gigl/distributed/dist_ppr_sampler.py b/gigl/distributed/dist_ppr_sampler.py index c6120cffa..69ea230f5 100644 --- a/gigl/distributed/dist_ppr_sampler.py +++ b/gigl/distributed/dist_ppr_sampler.py @@ -17,7 +17,7 @@ from graphlearn_torch.utils import merge_dict from gigl.distributed.base_sampler import BaseDistNeighborSampler -from gigl.types.graph import is_label_edge_type +from gigl.types.graph import DEFAULT_HOMOGENEOUS_NODE_TYPE, is_label_edge_type # Trailing "." is an intentional separator. These constants are used both to # write metadata keys (f"{KEY}{repr(edge_type)}" → e.g. "ppr_edge_index.('user', 'to', 'story')") @@ -26,103 +26,17 @@ PPR_EDGE_INDEX_METADATA_KEY = "ppr_edge_index." PPR_WEIGHT_METADATA_KEY = "ppr_weight." -# Sentinel type names for homogeneous graphs. The PPR algorithm uses -# dict[NodeType, ...] internally for both homo and hetero graphs; these -# sentinels let the homogeneous path reuse the same dict-based code. -_PPR_HOMOGENEOUS_NODE_TYPE = "ppr_homogeneous_node_type" +# Sentinel edge type for homogeneous graphs. The PPR algorithm uses +# dict[NodeType, ...] internally for both homo and hetero graphs; the +# DEFAULT_HOMOGENEOUS_NODE_TYPE sentinel lets the homogeneous path reuse +# the same dict-based code. _PPR_HOMOGENEOUS_EDGE_TYPE = ( - _PPR_HOMOGENEOUS_NODE_TYPE, + DEFAULT_HOMOGENEOUS_NODE_TYPE, "to", - _PPR_HOMOGENEOUS_NODE_TYPE, + DEFAULT_HOMOGENEOUS_NODE_TYPE, ) -def build_ppr_node_type_to_edge_types( - is_homogeneous: bool, - edge_types: list[EdgeType], - edge_dir: str, -) -> dict[NodeType, list[EdgeType]]: - """Build the node-type → edge-types mapping used by the PPR forward-push kernel. - - For homogeneous graphs returns the singleton sentinel mapping. For - heterogeneous graphs, groups non-label edge types by their anchor node type - (destination for ``edge_dir="in"``, source for ``edge_dir="out"``). - - Args: - is_homogeneous: True if the graph has a single node/edge type. - edge_types: All edge types present in the graph (ignored when homogeneous). - edge_dir: Sampling direction — ``"in"`` or ``"out"``. - - Returns: - Dict mapping each anchor NodeType to the list of EdgeTypes traversable - from it during a PPR walk. - """ - if is_homogeneous: - return {_PPR_HOMOGENEOUS_NODE_TYPE: [_PPR_HOMOGENEOUS_EDGE_TYPE]} - - node_type_to_edge_types: dict[NodeType, list[EdgeType]] = defaultdict(list) - for etype in edge_types: - if is_label_edge_type(etype): - continue - anchor_type = etype[-1] if edge_dir == "in" else etype[0] - node_type_to_edge_types[anchor_type].append(etype) - return dict(node_type_to_edge_types) - - -def build_ppr_total_degree_tensors( - degree_tensors: Union[torch.Tensor, dict[EdgeType, torch.Tensor]], - node_type_to_edge_types: dict[NodeType, list[EdgeType]], -) -> dict[NodeType, torch.Tensor]: - """Pre-compute total-degree tensors for the PPR forward-push kernel. - - For homogeneous graphs converts the single degree tensor to int16. - For heterogeneous graphs sums per-edge-type degrees into a per-node-type - total (capped at int16 max), padding shorter tensors with zeros where node - counts differ. - - This function is intentionally standalone so it can be called once in the - parent process (and the result shared across workers) rather than redundantly - inside each worker's ``DistPPRNeighborSampler.__init__``. - - Args: - degree_tensors: Per-edge-type degree tensors (homogeneous: single - ``torch.Tensor``; heterogeneous: ``dict[EdgeType, torch.Tensor]``). - node_type_to_edge_types: Mapping from anchor NodeType to the list of - EdgeTypes traversable from it, as returned by - :func:`build_ppr_node_type_to_edge_types`. - - Returns: - Dict mapping NodeType to a 1-D total-degree tensor of shape - ``[num_nodes_of_that_type]`` with dtype ``torch.int16``, capped at - ``torch.iinfo(torch.int16).max``. - - Raises: - ValueError: If a required edge type is missing from ``degree_tensors``. - """ - _INT16_MAX = torch.iinfo(torch.int16).max - result: dict[NodeType, torch.Tensor] = {} - - if isinstance(degree_tensors, torch.Tensor): - result[_PPR_HOMOGENEOUS_NODE_TYPE] = degree_tensors.to(torch.int16) - else: - for node_type, edge_types in node_type_to_edge_types.items(): - max_len = 0 - for et in edge_types: - if et not in degree_tensors: - raise ValueError( - f"Edge type {et} not found in degree tensors. " - f"Available: {list(degree_tensors.keys())}" - ) - max_len = max(max_len, len(degree_tensors[et])) - summed = torch.zeros(max_len, dtype=torch.int64) - for et in edge_types: - et_degrees = degree_tensors[et] - summed[: len(et_degrees)] += et_degrees.to(torch.int64) - result[node_type] = summed.clamp(max=_INT16_MAX).to(torch.int16) - - return result - - class DistPPRNeighborSampler(BaseDistNeighborSampler): """Personalized PageRank (PPR) based distributed neighbor sampler. @@ -210,7 +124,7 @@ def __init__( self._node_type_to_edge_types[anchor_type].append(etype) else: - self._node_type_to_edge_types[_PPR_HOMOGENEOUS_NODE_TYPE] = [ + self._node_type_to_edge_types[DEFAULT_HOMOGENEOUS_NODE_TYPE] = [ _PPR_HOMOGENEOUS_EDGE_TYPE ] self._is_homogeneous = True @@ -389,7 +303,7 @@ async def _compute_ppr_scores( valid_counts = tensor([1, 3, 2, 0]) """ if seed_node_type is None: - seed_node_type = _PPR_HOMOGENEOUS_NODE_TYPE + seed_node_type = DEFAULT_HOMOGENEOUS_NODE_TYPE device = seed_nodes.device ppr_state = PPRForwardPush( @@ -449,12 +363,12 @@ async def _compute_ppr_scores( if self._is_homogeneous: assert ( len(ntype_to_flat_ids) == 1 - and _PPR_HOMOGENEOUS_NODE_TYPE in ntype_to_flat_ids + and DEFAULT_HOMOGENEOUS_NODE_TYPE in ntype_to_flat_ids ) return ( - ntype_to_flat_ids[_PPR_HOMOGENEOUS_NODE_TYPE], - ntype_to_flat_weights[_PPR_HOMOGENEOUS_NODE_TYPE], - ntype_to_valid_counts[_PPR_HOMOGENEOUS_NODE_TYPE], + ntype_to_flat_ids[DEFAULT_HOMOGENEOUS_NODE_TYPE], + ntype_to_flat_weights[DEFAULT_HOMOGENEOUS_NODE_TYPE], + ntype_to_valid_counts[DEFAULT_HOMOGENEOUS_NODE_TYPE], ) else: return ( diff --git a/gigl/distributed/graph_store/shared_dist_sampling_producer.py b/gigl/distributed/graph_store/shared_dist_sampling_producer.py index 6712ac850..b45f8deae 100644 --- a/gigl/distributed/graph_store/shared_dist_sampling_producer.py +++ b/gigl/distributed/graph_store/shared_dist_sampling_producer.py @@ -93,15 +93,11 @@ SamplingConfig, SamplingType, ) -from graphlearn_torch.typing import EdgeType, NodeType +from graphlearn_torch.typing import NodeType from torch._C import _set_worker_signal_handlers from gigl.common.logger import Logger -from gigl.distributed.dist_ppr_sampler import ( - build_ppr_node_type_to_edge_types, - build_ppr_total_degree_tensors, -) -from gigl.distributed.sampler_options import PPRSamplerOptions, SamplerOptions +from gigl.distributed.sampler_options import SamplerOptions from gigl.distributed.utils.dist_sampler import ( SamplerInput, SamplerRuntime, @@ -840,7 +836,7 @@ def __init__( worker_options: RemoteDistSamplingWorkerOptions, sampling_config: SamplingConfig, sampler_options: SamplerOptions, - degree_tensors: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]], + degree_tensors: Optional[dict[NodeType, torch.Tensor]], ) -> None: """Initialize the shared sampling backend. @@ -876,37 +872,9 @@ def __init__( self._completed_workers: defaultdict[tuple[int, int], set[int]] = defaultdict( set ) - # For PPR sampling, pre-compute the total-degree dict (summed across edge - # types, converted to the target dtype) once here in the parent process. - # Workers receive the result directly as degree_tensors and skip the - # per-worker summation in DistPPRNeighborSampler._build_total_degree_tensors. - # - # Then move to shared memory so all spawned workers map the same - # allocation instead of each pickling a private copy. In colocated mode - # DistDataset.to_ipc_handle() handles shared memory; here the tensors - # arrive via RPC and are plain heap allocations without this call. - if ( - isinstance(sampler_options, PPRSamplerOptions) - and degree_tensors is not None - ): - assert data.graph is not None, ( - "DistDataset.graph must be set for PPR sampling" - ) - is_homogeneous = not isinstance(data.graph, dict) - edge_types = list(data.graph.keys()) if isinstance(data.graph, dict) else [] - node_type_to_edge_types = build_ppr_node_type_to_edge_types( - is_homogeneous=is_homogeneous, - edge_types=edge_types, - edge_dir=data.edge_dir, - ) - self._degree_tensors: Optional[dict[NodeType, torch.Tensor]] = ( - build_ppr_total_degree_tensors( - degree_tensors=degree_tensors, - node_type_to_edge_types=node_type_to_edge_types, - ) - ) - else: - self._degree_tensors = None + # Move degree tensors to shared memory so all spawned workers map the + # same allocation instead of each pickling a private copy. + self._degree_tensors: Optional[dict[NodeType, torch.Tensor]] = degree_tensors share_memory(self._degree_tensors) def init_backend(self) -> None: diff --git a/gigl/distributed/utils/degree.py b/gigl/distributed/utils/degree.py index 7374f53ed..eab3e7ec3 100644 --- a/gigl/distributed/utils/degree.py +++ b/gigl/distributed/utils/degree.py @@ -5,8 +5,9 @@ and aggregate them across distributed machines. Degrees are computed from the CSR (Compressed Sparse Row) topology stored in GraphLearn-Torch Graph objects. -Note: Degree tensors are not moved to shared memory and may be duplicated across -processes on the same machine. +Degrees are accumulated per anchor node type (summing across all edge types +incident to that node type) before the distributed all-reduce, so callers +receive ``dict[NodeType, torch.Tensor]`` directly with no further conversion. Requirements ============ @@ -27,24 +28,28 @@ import torch from graphlearn_torch.data import Graph +from graphlearn_torch.typing import NodeType from torch_geometric.typing import EdgeType from gigl.common.logger import Logger from gigl.distributed.utils.device import get_device_from_process_group from gigl.distributed.utils.networking import get_internal_ip_from_all_ranks -from gigl.types.graph import is_label_edge_type +from gigl.types.graph import DEFAULT_HOMOGENEOUS_NODE_TYPE, is_label_edge_type logger = Logger() def compute_and_broadcast_degree_tensor( graph: Union[Graph, dict[EdgeType, Graph]], -) -> Union[torch.Tensor, dict[EdgeType, torch.Tensor]]: - """ - Compute node degrees from a graph and aggregate across all machines. + edge_dir: str, +) -> dict[NodeType, torch.Tensor]: + """Compute node degrees from a graph and aggregate across all machines. - Computes degrees from the CSR row pointers (indptr) and performs all-reduce - to aggregate across ranks. + For each non-label edge type, degrees are derived from the CSR row pointers + (indptr). For heterogeneous graphs, degrees are summed across all edge types + incident to each anchor node type **locally** before the all-reduce, so the + per-edge-type tensor is only a transient intermediate and is never stored, + returned, or transmitted over RPC. Over-counting correction (for processes sharing the same data) is handled automatically by detecting the distributed topology. @@ -52,13 +57,17 @@ def compute_and_broadcast_degree_tensor( Args: graph: A Graph (homogeneous) or dict[EdgeType, Graph] (heterogeneous). For heterogeneous graphs, label edge types are automatically excluded - from the computation — they are supervision edges and should not - contribute to node degree for graph traversal algorithms like PPR. + — they are supervision edges and should not contribute to node degree + for graph traversal algorithms like PPR. + edge_dir: Sampling direction — ``"in"`` or ``"out"``. Determines which + end of each edge is the anchor node type for degree accumulation. Returns: - Union[torch.Tensor, dict[EdgeType, torch.Tensor]]: The aggregated degree tensors. - - For homogeneous graphs: A tensor of shape [num_nodes]. - - For heterogeneous graphs: A dict mapping non-label EdgeType to degree tensors. + dict[NodeType, torch.Tensor]: Aggregated degree tensors keyed by node + type. For homogeneous graphs the single entry uses + ``DEFAULT_HOMOGENEOUS_NODE_TYPE`` as its key. Values are int16 + tensors of shape ``[num_nodes_of_that_type]``, capped at + ``torch.iinfo(torch.int16).max``. Raises: RuntimeError: If torch.distributed is not initialized. @@ -69,52 +78,51 @@ def compute_and_broadcast_degree_tensor( "compute_and_broadcast_degree_tensor requires torch.distributed to be initialized." ) - # Compute local degrees from graph topology + local_dict: dict[NodeType, torch.Tensor] = {} + if isinstance(graph, Graph): topo = graph.topo if topo is None or topo.indptr is None: raise ValueError("Topology/indptr not available for graph.") - local_degrees: Union[torch.Tensor, dict[EdgeType, torch.Tensor]] = ( - _compute_degrees_from_indptr(topo.indptr) + local_dict[DEFAULT_HOMOGENEOUS_NODE_TYPE] = _compute_degrees_from_indptr( + topo.indptr ) else: - local_dict: dict[EdgeType, torch.Tensor] = {} for edge_type, edge_graph in graph.items(): - # Label edge types are supervision edges and should not contribute - # to node degree for graph traversal algorithms like PPR. if is_label_edge_type(edge_type): continue + anchor_type: NodeType = edge_type[-1] if edge_dir == "in" else edge_type[0] topo = edge_graph.topo if topo is None or topo.indptr is None: logger.warning( f"Topology/indptr not available for edge type {edge_type}, using empty tensor." ) - local_dict[edge_type] = torch.empty(0, dtype=torch.int16) + degrees = torch.empty(0, dtype=torch.int16) + else: + degrees = _compute_degrees_from_indptr(topo.indptr) + + if anchor_type in local_dict: + # Accumulate in int64 to avoid overflow, clamp back to int16 + existing = local_dict[anchor_type] + max_len = max(len(existing), len(degrees)) + summed = _pad_to_size(existing, max_len).to(torch.int64) + summed[: len(degrees)] += degrees.to(torch.int64) + local_dict[anchor_type] = _clamp_to_int16(summed) else: - local_dict[edge_type] = _compute_degrees_from_indptr(topo.indptr) - local_degrees = local_dict + local_dict[anchor_type] = degrees - # All-reduce across ranks (over-counting correction handled internally) - result = _all_reduce_degrees(local_degrees) + result = _all_reduce_degrees(local_dict) - # Log results - if isinstance(result, torch.Tensor): - if result.numel() > 0: + for node_type, degrees in result.items(): + if degrees.numel() > 0: logger.info( - f"{result.size(0)} nodes, max={result.max().item()}, min={result.min().item()}" + f"{node_type}: {degrees.size(0)} nodes, " + f"max={degrees.max().item()}, min={degrees.min().item()}" ) else: - logger.info("Graph contained 0 nodes when computing degrees") - else: - for edge_type, degrees in result.items(): - if degrees.numel() > 0: - logger.info( - f"{edge_type}: {degrees.size(0)} nodes, max={degrees.max().item()}, min={degrees.min().item()}" - ) - else: - logger.info( - f"Graph contained 0 nodes for edge type {edge_type} when computing degrees" - ) + logger.info( + f"Graph contained 0 nodes for node type {node_type} when computing degrees" + ) return result @@ -143,21 +151,19 @@ def _compute_degrees_from_indptr(indptr: torch.Tensor) -> torch.Tensor: def _all_reduce_degrees( - local_degrees: Union[torch.Tensor, dict[EdgeType, torch.Tensor]], -) -> Union[torch.Tensor, dict[EdgeType, torch.Tensor]]: - """All-reduce degree tensors across ranks, handling both homogeneous and heterogeneous cases. + local_degrees: dict[NodeType, torch.Tensor], +) -> dict[NodeType, torch.Tensor]: + """All-reduce degree tensors across ranks. - For heterogeneous graphs, iterates over the edge types in local_degrees. All partitions - are expected to have entries for all edge types (even if some have empty tensors). - - Moves tensors to GPU for the all-reduce if using NCCL backend (which requires CUDA), - otherwise keeps tensors on CPU (for Gloo backend). + Moves tensors to GPU for the all-reduce if using NCCL backend (which + requires CUDA), otherwise keeps tensors on CPU (for Gloo backend). Over-counting correction: - In distributed training, multiple processes on the same machine often share the - same graph partition data (via shared memory). When we all-reduce degrees, each - process contributes its "local" degrees - but if 4 processes on one machine all - read the same partition, that partition's degrees get summed 4 times instead of 1. + In distributed training, multiple processes on the same machine often + share the same graph partition data (via shared memory). When we + all-reduce degrees, each process contributes its "local" degrees — but + if 4 processes on one machine all read the same partition, that + partition's degrees get summed 4 times instead of 1. Example: Machine A has 2 processes sharing partition with degrees [3, 5, 2]. Machine B has 2 processes sharing partition with degrees [1, 4, 6]. @@ -168,16 +174,16 @@ def _all_reduce_degrees( With correction: divide by local_world_size (2 per machine) = [4, 9, 8] (correct: [3+1, 5+4, 2+6]) - This function detects how many processes share the same machine by comparing - IP addresses, then divides by that count to correct the over-counting. + This function detects how many processes share the same machine by + comparing IP addresses, then divides by that count to correct the + over-counting. Args: - local_degrees: Either a single tensor (homogeneous) or dict mapping EdgeType - to tensors (heterogeneous). For heterogeneous graphs, all partitions must - have entries for all edge types. + local_degrees: Dict mapping NodeType to local degree tensors. + All partitions must have entries for all node types. Returns: - Aggregated degree tensors in the same format as input. + Aggregated degree tensors keyed by NodeType. Raises: RuntimeError: If torch.distributed is not initialized. @@ -187,38 +193,25 @@ def _all_reduce_degrees( "_all_reduce_degrees requires torch.distributed to be initialized." ) - # Compute local_world_size: number of processes on the same machine sharing data all_ips = get_internal_ip_from_all_ranks() my_rank = torch.distributed.get_rank() my_ip = all_ips[my_rank] local_world_size = Counter(all_ips)[my_ip] - # NCCL backend requires CUDA tensors; Gloo works with CPU device = get_device_from_process_group() def reduce_tensor(tensor: torch.Tensor) -> torch.Tensor: """All-reduce a single tensor with size sync and over-counting correction.""" - # Synchronize max size across all ranks local_size = torch.tensor([tensor.size(0)], dtype=torch.long, device=device) torch.distributed.all_reduce(local_size, op=torch.distributed.ReduceOp.MAX) max_size = int(local_size.item()) - # Pad, convert to int64 (all_reduce doesn't support int16), move to device padded = _pad_to_size(tensor, max_size).to(torch.int64).to(device) torch.distributed.all_reduce(padded, op=torch.distributed.ReduceOp.SUM) - # Correct for over-counting, move back to CPU, and clamp to int16 - # TODO (mkolodner-sc): Potentially want to paramaterize this in the future if we want degrees higher than the int16 max. return _clamp_to_int16((padded // local_world_size).cpu()) - # Homogeneous case - if isinstance(local_degrees, torch.Tensor): - return reduce_tensor(local_degrees) - - # Heterogeneous case: all-reduce each edge type - # Sort edge types for deterministic ordering across ranks - result: dict[EdgeType, torch.Tensor] = {} - for edge_type in sorted(local_degrees.keys()): - result[edge_type] = reduce_tensor(local_degrees[edge_type]) - + result: dict[NodeType, torch.Tensor] = {} + for node_type in sorted(local_degrees.keys()): + result[node_type] = reduce_tensor(local_degrees[node_type]) return result From 5ca621ccb93f7ffaeb3b8c9b1a071180d3124329 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Tue, 19 May 2026 22:55:42 +0000 Subject: [PATCH 05/32] Change int16 to int32 --- gigl/distributed/dist_ppr_sampler.py | 6 +++--- gigl/distributed/utils/degree.py | 20 ++++++-------------- 2 files changed, 9 insertions(+), 17 deletions(-) diff --git a/gigl/distributed/dist_ppr_sampler.py b/gigl/distributed/dist_ppr_sampler.py index 69ea230f5..dc9671974 100644 --- a/gigl/distributed/dist_ppr_sampler.py +++ b/gigl/distributed/dist_ppr_sampler.py @@ -74,8 +74,8 @@ class DistPPRNeighborSampler(BaseDistNeighborSampler): but require more computation. Typical values: 1e-4 to 1e-6. max_ppr_nodes: Maximum number of nodes to return per seed based on PPR scores. num_neighbors_per_hop: Maximum number of neighbors to fetch per hop. - degree_tensors: Pre-computed total-degree tensors (int16, capped at - int16 max), keyed by NodeType. Must be pre-computed by the caller + degree_tensors: Pre-computed total-degree tensors (int32), keyed by NodeType. + Must be pre-computed by the caller (e.g. via :func:`build_ppr_total_degree_tensors`) so that workers share a single allocation rather than recomputing per-worker. """ @@ -180,7 +180,7 @@ def __init__( # Degree tensors indexed by ntype_id. Destination-only types get an empty # tensor; the C++ kernel returns 0 for those, matching _get_total_degree. self._degree_tensors_for_cpp: list[torch.Tensor] = [ - self._node_type_to_total_degree.get(nt, torch.zeros(0, dtype=torch.int16)) + self._node_type_to_total_degree.get(nt, torch.zeros(0, dtype=torch.int32)) for nt in all_node_types ] diff --git a/gigl/distributed/utils/degree.py b/gigl/distributed/utils/degree.py index eab3e7ec3..d33ec74f0 100644 --- a/gigl/distributed/utils/degree.py +++ b/gigl/distributed/utils/degree.py @@ -65,9 +65,8 @@ def compute_and_broadcast_degree_tensor( Returns: dict[NodeType, torch.Tensor]: Aggregated degree tensors keyed by node type. For homogeneous graphs the single entry uses - ``DEFAULT_HOMOGENEOUS_NODE_TYPE`` as its key. Values are int16 - tensors of shape ``[num_nodes_of_that_type]``, capped at - ``torch.iinfo(torch.int16).max``. + ``DEFAULT_HOMOGENEOUS_NODE_TYPE`` as its key. Values are int32 + tensors of shape ``[num_nodes_of_that_type]``. Raises: RuntimeError: If torch.distributed is not initialized. @@ -97,17 +96,16 @@ def compute_and_broadcast_degree_tensor( logger.warning( f"Topology/indptr not available for edge type {edge_type}, using empty tensor." ) - degrees = torch.empty(0, dtype=torch.int16) + degrees = torch.empty(0, dtype=torch.int32) else: degrees = _compute_degrees_from_indptr(topo.indptr) if anchor_type in local_dict: - # Accumulate in int64 to avoid overflow, clamp back to int16 existing = local_dict[anchor_type] max_len = max(len(existing), len(degrees)) summed = _pad_to_size(existing, max_len).to(torch.int64) summed[: len(degrees)] += degrees.to(torch.int64) - local_dict[anchor_type] = _clamp_to_int16(summed) + local_dict[anchor_type] = summed.to(torch.int32) else: local_dict[anchor_type] = degrees @@ -139,15 +137,9 @@ def _pad_to_size(tensor: torch.Tensor, target_size: int) -> torch.Tensor: return torch.cat([tensor, padding]) -def _clamp_to_int16(tensor: torch.Tensor) -> torch.Tensor: - """Clamp tensor values to int16 max and convert dtype.""" - max_int16 = torch.iinfo(torch.int16).max - return tensor.clamp(max=max_int16).to(torch.int16) - - def _compute_degrees_from_indptr(indptr: torch.Tensor) -> torch.Tensor: """Compute degrees from CSR row pointers: degree[i] = indptr[i+1] - indptr[i].""" - return (indptr[1:] - indptr[:-1]).to(torch.int16) + return (indptr[1:] - indptr[:-1]).to(torch.int32) def _all_reduce_degrees( @@ -209,7 +201,7 @@ def reduce_tensor(tensor: torch.Tensor) -> torch.Tensor: padded = _pad_to_size(tensor, max_size).to(torch.int64).to(device) torch.distributed.all_reduce(padded, op=torch.distributed.ReduceOp.SUM) - return _clamp_to_int16((padded // local_world_size).cpu()) + return (padded // local_world_size).to(torch.int32).cpu() result: dict[NodeType, torch.Tensor] = {} for node_type in sorted(local_degrees.keys()): From ac2ef26f616f30282154fd4af82a639ead661326 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Thu, 28 May 2026 18:43:59 +0000 Subject: [PATCH 06/32] Fix degree tensor tests and type checks --- .../heterogeneous_inference.py | 5 +- gigl/common/metrics/decorators.py | 4 +- gigl/distributed/dist_dataset.py | 2 +- gigl/distributed/dist_ppr_sampler.py | 11 +- .../shared_dist_sampling_producer.py | 4 +- .../node_classification_modeling_task_spec.py | 3 +- .../dataset_input_metadata_translator_test.py | 24 ++- tests/unit/distributed/utils/degree_test.py | 160 ++++++++++++------ 8 files changed, 145 insertions(+), 68 deletions(-) diff --git a/examples/link_prediction/heterogeneous_inference.py b/examples/link_prediction/heterogeneous_inference.py index 9aeda018f..c2f047926 100644 --- a/examples/link_prediction/heterogeneous_inference.py +++ b/examples/link_prediction/heterogeneous_inference.py @@ -23,7 +23,7 @@ import gc import time from dataclasses import dataclass -from typing import Optional, Union +from typing import Optional, Union, cast import torch import torch.distributed @@ -155,6 +155,9 @@ def _inference_process( assert isinstance(node_type_to_input_node_ids, dict), ( f"Node IDs must be a dictionary for heterogeneous inference, got {type(node_type_to_input_node_ids)}" ) + node_type_to_input_node_ids = cast( + dict[NodeType, torch.Tensor], node_type_to_input_node_ids + ) input_node_ids: torch.Tensor = node_type_to_input_node_ids[args.inference_node_type] data_loader = gigl.distributed.DistNeighborLoader( diff --git a/gigl/common/metrics/decorators.py b/gigl/common/metrics/decorators.py index c09ee50bf..6f84b8737 100644 --- a/gigl/common/metrics/decorators.py +++ b/gigl/common/metrics/decorators.py @@ -22,6 +22,7 @@ def __safely_flush_metrics( Callable[[], Optional[OpsMetricPublisher]] ], ) -> None: + metrics_instance = None if get_metrics_service_instance_fn is not None: metrics_instance = get_metrics_service_instance_fn() if metrics_instance is not None: @@ -45,8 +46,9 @@ def wrap(*args: Any, **kwargs: Any) -> Any: try: result = func(*args, **kwargs) except Exception as e: + func_name = getattr(func, "__name__", repr(func)) logger.info( - f"Exception raised, will flush metrics for: {func.__name__} and re-raise exception" + f"Exception raised, will flush metrics for: {func_name} and re-raise exception" ) logger.error(f"Exception: {e}") logger.error(traceback.format_exc()) diff --git a/gigl/distributed/dist_dataset.py b/gigl/distributed/dist_dataset.py index c0cf6f207..0cbe88301 100644 --- a/gigl/distributed/dist_dataset.py +++ b/gigl/distributed/dist_dataset.py @@ -879,7 +879,7 @@ def share_ipc( Optional[Union[int, dict[NodeType, int]]], Optional[Union[FeatureInfo, dict[NodeType, FeatureInfo]]], Optional[Union[FeatureInfo, dict[EdgeType, FeatureInfo]]], - Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]], + Optional[dict[NodeType, torch.Tensor]], Optional[int], ]: """ diff --git a/gigl/distributed/dist_ppr_sampler.py b/gigl/distributed/dist_ppr_sampler.py index 69ea230f5..f4e96388e 100644 --- a/gigl/distributed/dist_ppr_sampler.py +++ b/gigl/distributed/dist_ppr_sampler.py @@ -76,8 +76,8 @@ class DistPPRNeighborSampler(BaseDistNeighborSampler): num_neighbors_per_hop: Maximum number of neighbors to fetch per hop. degree_tensors: Pre-computed total-degree tensors (int16, capped at int16 max), keyed by NodeType. Must be pre-computed by the caller - (e.g. via :func:`build_ppr_total_degree_tensors`) so that workers - share a single allocation rather than recomputing per-worker. + through ``DistDataset.degree_tensor`` so that workers share a single + allocation rather than recomputing per-worker. """ def __init__( @@ -130,10 +130,9 @@ def __init__( self._is_homogeneous = True # Total-degree tensors keyed by NodeType, pre-computed by the caller. - # Callers (create_mp_producer for colocated, SharedDistSamplingBackend - # for graph-store) run build_ppr_total_degree_tensors once in the parent - # process and place the result in shared memory so all worker processes - # map the same allocation. + # Callers compute DistDataset.degree_tensor once in the parent process + # and place the result in shared memory so all worker processes map the + # same allocation. self._node_type_to_total_degree: dict[NodeType, torch.Tensor] = degree_tensors # Build integer ID mappings for the C++ forward-push kernel. String diff --git a/gigl/distributed/graph_store/shared_dist_sampling_producer.py b/gigl/distributed/graph_store/shared_dist_sampling_producer.py index b45f8deae..c6564a39d 100644 --- a/gigl/distributed/graph_store/shared_dist_sampling_producer.py +++ b/gigl/distributed/graph_store/shared_dist_sampling_producer.py @@ -364,8 +364,8 @@ def _shared_sampling_worker_loop( sampler_options: GiGL sampler configuration (e.g. ``PPRSamplerOptions`` for PPR-based sampling). degree_tensors: Pre-computed degree tensors for PPR sampling, or - ``None`` for non-PPR samplers. Materialized once in the parent - process by ``_prepare_degree_tensors`` and shared across workers. + ``None`` for non-PPR samplers. Materialized once in the parent via + ``DistDataset.degree_tensor`` and shared across workers. Algorithm: 1. Initialize RPC, sampler infrastructure, and signal the parent via barrier. diff --git a/gigl/src/common/modeling_task_specs/node_classification_modeling_task_spec.py b/gigl/src/common/modeling_task_specs/node_classification_modeling_task_spec.py index bfaba1fb0..66809b5a2 100644 --- a/gigl/src/common/modeling_task_specs/node_classification_modeling_task_spec.py +++ b/gigl/src/common/modeling_task_specs/node_classification_modeling_task_spec.py @@ -201,8 +201,9 @@ def score( assert root_node_labels is not None results: InferBatchResults = self.infer_batch(batch=batch, device=device) + assert results.predictions is not None num_correct_in_batch = int( - (results.predictions == root_node_labels).sum() + torch.eq(results.predictions, root_node_labels).sum().item() ) # https://github.com/Snapchat/GiGL/issues/408 num_correct += num_correct_in_batch num_evaluated += len(batch.root_node_labels) diff --git a/tests/unit/distributed/dataset_input_metadata_translator_test.py b/tests/unit/distributed/dataset_input_metadata_translator_test.py index e5c5709b8..2899848f3 100644 --- a/tests/unit/distributed/dataset_input_metadata_translator_test.py +++ b/tests/unit/distributed/dataset_input_metadata_translator_test.py @@ -114,11 +114,17 @@ def test_translator_correctness(self, _, mocked_dataset_info: MockedDatasetInfo) ) if isinstance(serialized_graph_metadata.node_entity_info, abc.Mapping): - serialized_node_info_iterable = list( - serialized_graph_metadata.node_entity_info.values() + serialized_node_info_iterable = cast( + list[SerializedTFRecordInfo], + list(serialized_graph_metadata.node_entity_info.values()), ) else: - serialized_node_info_iterable = [serialized_graph_metadata.node_entity_info] + serialized_node_info_iterable = [ + cast( + SerializedTFRecordInfo, + serialized_graph_metadata.node_entity_info, + ) + ] self.assertEqual( len(graph_metadata_pb_wrapper.node_types), @@ -189,11 +195,17 @@ def test_translator_correctness(self, _, mocked_dataset_info: MockedDatasetInfo) ) if isinstance(serialized_graph_metadata.edge_entity_info, abc.Mapping): - serialized_edge_info_iterable = list( - serialized_graph_metadata.edge_entity_info.values() + serialized_edge_info_iterable = cast( + list[SerializedTFRecordInfo], + list(serialized_graph_metadata.edge_entity_info.values()), ) else: - serialized_edge_info_iterable = [serialized_graph_metadata.edge_entity_info] + serialized_edge_info_iterable = [ + cast( + SerializedTFRecordInfo, + serialized_graph_metadata.edge_entity_info, + ) + ] self.assertEqual( len(graph_metadata_pb_wrapper.edge_types), diff --git a/tests/unit/distributed/utils/degree_test.py b/tests/unit/distributed/utils/degree_test.py index ffcb6e5a4..780488472 100644 --- a/tests/unit/distributed/utils/degree_test.py +++ b/tests/unit/distributed/utils/degree_test.py @@ -1,3 +1,5 @@ +from typing import Literal + import torch import torch.distributed as dist import torch.multiprocessing as mp @@ -10,6 +12,8 @@ _pad_to_size, compute_and_broadcast_degree_tensor, ) +from gigl.src.common.types.graph_data import EdgeType, NodeType +from gigl.types.graph import DEFAULT_HOMOGENEOUS_NODE_TYPE from tests.test_assets.distributed.test_dataset import ( DEFAULT_HETEROGENEOUS_EDGE_INDICES, DEFAULT_HOMOGENEOUS_EDGE_INDEX, @@ -25,16 +29,58 @@ def _compute_expected_degrees_from_edge_index( - edge_index: torch.Tensor, num_nodes: int + edge_index: torch.Tensor, num_nodes: int, node_axis: int = 0 ) -> torch.Tensor: - """Compute expected out-degrees from COO edge index.""" - src_nodes = edge_index[0] + """Compute expected degrees from a COO edge index along one endpoint axis.""" + nodes = edge_index[node_axis] degrees = torch.zeros(num_nodes, dtype=torch.int16) - for src in src_nodes: - degrees[src] += 1 + for node in nodes: + degrees[node] += 1 return degrees +def _get_anchor_node_type( + edge_type: EdgeType, edge_dir: Literal["in", "out"] +) -> NodeType: + """Return the node type whose CSR rows define traversable degrees.""" + return edge_type.dst_node_type if edge_dir == "in" else edge_type.src_node_type + + +def _compute_expected_total_degrees_by_node_type( + edge_indices: dict[EdgeType, torch.Tensor], + edge_dir: Literal["in", "out"], +) -> dict[NodeType, torch.Tensor]: + """Compute total degrees keyed by anchor node type.""" + node_axis = 1 if edge_dir == "in" else 0 + expected: dict[NodeType, torch.Tensor] = {} + for edge_type, edge_index in edge_indices.items(): + anchor_node_type = _get_anchor_node_type(edge_type, edge_dir) + num_nodes = ( + int(edge_index[node_axis].max().item() + 1) + if edge_index.shape[1] > 0 + else 0 + ) + degrees = _compute_expected_degrees_from_edge_index( + edge_index=edge_index, + num_nodes=num_nodes, + node_axis=node_axis, + ) + + if anchor_node_type not in expected: + expected[anchor_node_type] = degrees + continue + + max_len = max(expected[anchor_node_type].numel(), degrees.numel()) + summed_degrees = torch.zeros(max_len, dtype=torch.int64) + summed_degrees[: expected[anchor_node_type].numel()] += expected[ + anchor_node_type + ].to(torch.int64) + summed_degrees[: degrees.numel()] += degrees.to(torch.int64) + expected[anchor_node_type] = _clamp_to_int16(summed_degrees) + + return expected + + class TestDegreeComputation(TestCase): """Tests for degree computation with torch.distributed initialized. @@ -60,12 +106,12 @@ def test_homogeneous_graph(self): dataset = create_homogeneous_dataset(edge_index=edge_index) assert dataset.graph is not None - result = compute_and_broadcast_degree_tensor(dataset.graph) + result = compute_and_broadcast_degree_tensor(dataset.graph, dataset.edge_dir) - assert isinstance(result, torch.Tensor) + self.assertEqual(set(result.keys()), {DEFAULT_HOMOGENEOUS_NODE_TYPE}) expected = _compute_expected_degrees_from_edge_index(edge_index, num_nodes) - self.assertEqual(result.shape[0], num_nodes) - self.assert_tensor_equality(result, expected) + self.assertEqual(result[DEFAULT_HOMOGENEOUS_NODE_TYPE].shape[0], num_nodes) + self.assert_tensor_equality(result[DEFAULT_HOMOGENEOUS_NODE_TYPE], expected) def test_heterogeneous_graph(self): """Test degree computation for a heterogeneous graph.""" @@ -73,15 +119,16 @@ def test_heterogeneous_graph(self): dataset = create_heterogeneous_dataset(edge_indices=edge_indices) assert dataset.graph is not None - result = compute_and_broadcast_degree_tensor(dataset.graph) + result = compute_and_broadcast_degree_tensor(dataset.graph, dataset.edge_dir) - assert isinstance(result, dict) - self.assertEqual(set(result.keys()), set(edge_indices.keys())) + expected = _compute_expected_total_degrees_by_node_type( + edge_indices=edge_indices, + edge_dir=dataset.edge_dir, + ) + self.assertEqual(set(result.keys()), set(expected.keys())) - for edge_type, edge_index in edge_indices.items(): - num_nodes = int(edge_index[0].max().item() + 1) - expected = _compute_expected_degrees_from_edge_index(edge_index, num_nodes) - self.assert_tensor_equality(result[edge_type], expected) # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. + for node_type, expected_degrees in expected.items(): + self.assert_tensor_equality(result[node_type], expected_degrees) def test_heterogeneous_graph_with_missing_topology(self): """Test that edge types with missing topology get empty tensors. @@ -105,24 +152,37 @@ def test_heterogeneous_graph_with_missing_topology(self): # Save the original topology for computing expected degrees original_graph = dataset.graph[edge_type_with_topo] assert original_graph.topo is not None - expected_degrees = _compute_expected_degrees_from_edge_index( - edge_indices[edge_type_with_topo], - int(edge_indices[edge_type_with_topo][0].max().item() + 1), + expected_degrees = _compute_expected_total_degrees_by_node_type( + edge_indices={edge_type_with_topo: edge_indices[edge_type_with_topo]}, + edge_dir=dataset.edge_dir, ) # Manually set one graph's topology to None to test the edge case dataset.graph[edge_type_without_topo].topo = None - result = compute_and_broadcast_degree_tensor(dataset.graph) + result = compute_and_broadcast_degree_tensor(dataset.graph, dataset.edge_dir) - assert isinstance(result, dict) - self.assertEqual(set(result.keys()), set(edge_types)) + expected_node_types = { + _get_anchor_node_type(edge_type, dataset.edge_dir) + for edge_type in edge_types + } + self.assertEqual(set(result.keys()), expected_node_types) # Edge type with topology should have computed degrees - self.assert_tensor_equality(result[edge_type_with_topo], expected_degrees) + node_type_with_topo = _get_anchor_node_type( + edge_type=edge_type_with_topo, + edge_dir=dataset.edge_dir, + ) + self.assert_tensor_equality( + result[node_type_with_topo], expected_degrees[node_type_with_topo] + ) # Edge type without topology should have empty tensor - self.assertEqual(result[edge_type_without_topo].numel(), 0) + node_type_without_topo = _get_anchor_node_type( + edge_type=edge_type_without_topo, + edge_dir=dataset.edge_dir, + ) + self.assertEqual(result[node_type_without_topo].numel(), 0) def _run_local_world_size_correction_homogeneous( @@ -130,7 +190,7 @@ def _run_local_world_size_correction_homogeneous( world_size: int, init_method: str, edge_index: torch.Tensor, - expected_degrees: torch.Tensor, + expected_degrees: dict[NodeType, torch.Tensor], ) -> None: """Worker function for multi-process local_world_size correction test (homogeneous).""" dist.init_process_group( @@ -142,10 +202,11 @@ def _run_local_world_size_correction_homogeneous( try: dataset = create_homogeneous_dataset(edge_index=edge_index) assert dataset.graph is not None - result = compute_and_broadcast_degree_tensor(dataset.graph) + result = compute_and_broadcast_degree_tensor(dataset.graph, dataset.edge_dir) - assert isinstance(result, torch.Tensor) - assert_tensor_equality(result, expected_degrees) + assert set(result.keys()) == set(expected_degrees.keys()) + for node_type, expected in expected_degrees.items(): + assert_tensor_equality(result[node_type], expected) finally: dist.destroy_process_group() @@ -154,8 +215,8 @@ def _run_local_world_size_correction_heterogeneous( rank: int, world_size: int, init_method: str, - edge_indices: dict, - expected_degrees: dict, + edge_indices: dict[EdgeType, torch.Tensor], + expected_degrees: dict[NodeType, torch.Tensor], ) -> None: """Worker function for multi-process local_world_size correction test (heterogeneous).""" dist.init_process_group( @@ -167,12 +228,11 @@ def _run_local_world_size_correction_heterogeneous( try: dataset = create_heterogeneous_dataset(edge_indices=edge_indices) assert dataset.graph is not None - result = compute_and_broadcast_degree_tensor(dataset.graph) + result = compute_and_broadcast_degree_tensor(dataset.graph, dataset.edge_dir) - assert isinstance(result, dict) assert set(result.keys()) == set(expected_degrees.keys()) - for edge_type, expected in expected_degrees.items(): - assert_tensor_equality(result[edge_type], expected) + for node_type, expected in expected_degrees.items(): + assert_tensor_equality(result[node_type], expected) finally: dist.destroy_process_group() @@ -191,7 +251,9 @@ def test_local_world_size_correction_homogeneous(self): num_nodes = int(edge_index.max().item() + 1) raw_degrees = _compute_expected_degrees_from_edge_index(edge_index, num_nodes) - expected_degrees = raw_degrees # After correction: (2*raw) / 2 = raw + expected_degrees = { + DEFAULT_HOMOGENEOUS_NODE_TYPE: raw_degrees + } # After correction: (2*raw) / 2 = raw init_method = get_process_group_init_method() mp.spawn( @@ -204,13 +266,10 @@ def test_local_world_size_correction_heterogeneous(self): """Test over-counting correction for heterogeneous graphs with 2 processes.""" edge_indices = DEFAULT_HETEROGENEOUS_EDGE_INDICES - expected_degrees = {} - for edge_type, edge_index in edge_indices.items(): - num_nodes = int(edge_index[0].max().item() + 1) - raw_degrees = _compute_expected_degrees_from_edge_index( - edge_index, num_nodes - ) - expected_degrees[edge_type] = raw_degrees + expected_degrees = _compute_expected_total_degrees_by_node_type( + edge_indices=edge_indices, + edge_dir="out", + ) init_method = get_process_group_init_method() mp.spawn( @@ -242,9 +301,9 @@ def test_degree_tensor_homogeneous(self): dataset = create_homogeneous_dataset(edge_index=edge_index) result = dataset.degree_tensor - assert isinstance(result, torch.Tensor) + self.assertEqual(set(result.keys()), {DEFAULT_HOMOGENEOUS_NODE_TYPE}) expected = _compute_expected_degrees_from_edge_index(edge_index, num_nodes) - self.assert_tensor_equality(result, expected) + self.assert_tensor_equality(result[DEFAULT_HOMOGENEOUS_NODE_TYPE], expected) def test_degree_tensor_caches_result(self): """Test that degree_tensor property caches the result.""" @@ -262,13 +321,14 @@ def test_degree_tensor_heterogeneous(self): result = dataset.degree_tensor - assert isinstance(result, dict) - self.assertEqual(set(result.keys()), set(edge_indices.keys())) + expected = _compute_expected_total_degrees_by_node_type( + edge_indices=edge_indices, + edge_dir=dataset.edge_dir, + ) + self.assertEqual(set(result.keys()), set(expected.keys())) - for edge_type, edge_index in edge_indices.items(): - num_nodes = int(edge_index[0].max().item() + 1) - expected = _compute_expected_degrees_from_edge_index(edge_index, num_nodes) - self.assert_tensor_equality(result[edge_type], expected) # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. + for node_type, expected_degrees in expected.items(): + self.assert_tensor_equality(result[node_type], expected_degrees) class TestHelperFunctions(TestCase): From d850b37bf6a099d619c8960b1966af8f76ff9805 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Thu, 28 May 2026 19:14:31 +0000 Subject: [PATCH 07/32] Add E2E PPR graphstore test --- Makefile | 8 ++ .../e2e_hom_cora_sup_gs_ppr_task_config.yaml | 79 +++++++++++++++++++ .../graph_store/homogeneous_inference.py | 18 ++++- .../graph_store/homogeneous_training.py | 15 +++- gigl/utils/sampling.py | 41 ++++++++++ tests/e2e_tests/e2e_tests.yaml | 3 + 6 files changed, 161 insertions(+), 3 deletions(-) create mode 100644 examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml diff --git a/Makefile b/Makefile index f80eb2952..e378fdd76 100644 --- a/Makefile +++ b/Makefile @@ -270,6 +270,14 @@ run_hom_cora_sup_gs_e2e_test: --test_spec_uri="tests/e2e_tests/e2e_tests.yaml" \ --test_names="hom_cora_sup_gs_test" +run_hom_cora_sup_gs_ppr_e2e_test: compiled_pipeline_path:=${GIGL_E2E_TEST_COMPILED_PIPELINE_PATH} +run_hom_cora_sup_gs_ppr_e2e_test: compile_gigl_kubeflow_pipeline +run_hom_cora_sup_gs_ppr_e2e_test: + uv run python tests/e2e_tests/e2e_test.py \ + --compiled_pipeline_path=$(compiled_pipeline_path) \ + --test_spec_uri="tests/e2e_tests/e2e_tests.yaml" \ + --test_names="hom_cora_sup_gs_ppr_test" + run_het_dblp_sup_gs_e2e_test: compiled_pipeline_path:=${GIGL_E2E_TEST_COMPILED_PIPELINE_PATH} run_het_dblp_sup_gs_e2e_test: compile_gigl_kubeflow_pipeline run_het_dblp_sup_gs_e2e_test: diff --git a/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml b/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml new file mode 100644 index 000000000..1cff49a4c --- /dev/null +++ b/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml @@ -0,0 +1,79 @@ +# This config runs homogeneous CORA supervised training and inference in Graph Store mode +# with PPR sampling. It intentionally reuses the standard graph-store training/inference +# entrypoints, changing only the sampler args and keeping the loop short for E2E coverage. +graphMetadata: + edgeTypes: + - dstNodeType: paper + relation: cites + srcNodeType: paper + nodeTypes: + - paper +datasetConfig: + dataPreprocessorConfig: + dataPreprocessorConfigClsPath: gigl.src.mocking.mocking_assets.passthrough_preprocessor_config_for_mocked_assets.PassthroughPreprocessorConfigForMockedAssets + dataPreprocessorArgs: + mocked_dataset_name: 'cora_homogeneous_node_anchor_edge_features_user_defined_labels' +trainerConfig: + trainerArgs: + log_every_n_batch: "1" + num_neighbors: "[10, 10]" + sampler_type: "ppr" + ppr_alpha: "0.5" + ppr_eps: "0.0001" + ppr_max_nodes: "20" + ppr_neighbors_per_hop: "100" + ppr_max_fetch_iterations: "2" + local_world_size: "1" + sampling_workers_per_process: "1" + sampling_worker_shared_channel_size: "512MB" + main_batch_size: "8" + random_batch_size: "8" + num_max_train_batches: "2" + num_val_batches: "2" + val_every_n_batch: "1" + command: python -m examples.link_prediction.graph_store.homogeneous_training + graphStoreStorageConfig: + command: python -m examples.link_prediction.graph_store.storage_main + storageArgs: + sample_edge_direction: "in" + splitter_cls_path: "gigl.utils.data_splitters.DistNodeAnchorLinkSplitter" + splitter_kwargs: >- + { + "sampling_direction": "in", + "should_convert_labels_to_edges": True, + "num_val": 0.25, + "num_test": 0.25 + } + num_server_sessions: "1" +inferencerConfig: + inferencerArgs: + log_every_n_batch: "1" + num_neighbors: "[10, 10]" + sampler_type: "ppr" + ppr_alpha: "0.5" + ppr_eps: "0.0001" + ppr_max_nodes: "20" + ppr_neighbors_per_hop: "100" + ppr_max_fetch_iterations: "2" + local_world_size: "1" + sampling_workers_per_inference_process: "1" + sampling_worker_shared_channel_size: "512MB" + inferenceBatchSize: 256 + command: python -m examples.link_prediction.graph_store.homogeneous_inference + graphStoreStorageConfig: + command: python -m examples.link_prediction.graph_store.storage_main + storageArgs: + sample_edge_direction: "in" + num_server_sessions: "1" +sharedConfig: + shouldSkipInference: false + shouldSkipModelEvaluation: true +taskMetadata: + nodeAnchorBasedLinkPredictionTaskMetadata: + supervisionEdgeTypes: + - dstNodeType: paper + relation: cites + srcNodeType: paper +featureFlags: + should_run_glt_backend: 'True' + data_preprocessor_num_shards: '2' diff --git a/examples/link_prediction/graph_store/homogeneous_inference.py b/examples/link_prediction/graph_store/homogeneous_inference.py index 34bc2672e..eac5c519e 100644 --- a/examples/link_prediction/graph_store/homogeneous_inference.py +++ b/examples/link_prediction/graph_store/homogeneous_inference.py @@ -87,7 +87,7 @@ import sys import time from dataclasses import dataclass -from typing import Union +from typing import Optional, Union import torch import torch.multiprocessing as mp @@ -101,6 +101,7 @@ from gigl.common.utils.gcs import GcsUtils from gigl.distributed.graph_store.compute import init_compute_process from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset +from gigl.distributed.sampler_options import SamplerOptions from gigl.distributed.utils import get_graph_store_info from gigl.env.distributed import GraphStoreInfo from gigl.nn import LinkPredictionGNN @@ -110,7 +111,7 @@ from gigl.src.common.utils.bq import BqUtils from gigl.src.common.utils.model import load_state_dict_from_uri from gigl.src.inference.lib.assets import InferenceAssets -from gigl.utils.sampling import parse_fanout +from gigl.utils.sampling import parse_fanout, parse_sampler_options logger = Logger() @@ -143,6 +144,7 @@ class InferenceProcessArgs: inference_batch_size (int): Batch size to use for inference. num_neighbors (Union[list[int], dict[EdgeType, list[int]]]): Fanout for subgraph sampling, where the ith item corresponds to the number of items to sample for the ith hop. + sampler_options (Optional[SamplerOptions]): Sampler variant. None uses k-hop sampling. sampling_workers_per_inference_process (int): Number of sampling workers per inference process. sampling_worker_shared_channel_size (str): Shared-memory buffer size (bytes) allocated for @@ -169,6 +171,7 @@ class InferenceProcessArgs: # Inference configuration inference_batch_size: int num_neighbors: Union[list[int], dict[EdgeType, list[int]]] + sampler_options: Optional[SamplerOptions] sampling_workers_per_inference_process: int sampling_worker_shared_channel_size: str log_every_n_batch: int @@ -242,6 +245,7 @@ def _inference_process( # For large-scale settings, consider setting this field to 30-60 seconds to ensure dataloaders # don't compete for memory during initialization, causing OOM process_start_gap_seconds=0, + sampler_options=args.sampler_options, ) # Initialize a LinkPredictionGNN model and load parameters from # the saved model. @@ -494,6 +498,7 @@ def _run_example_inference( # Parses the fanout as a string. For the homogeneous case, the fanouts should be specified # as a string of a list of integers, such as "[10, 10]". num_neighbors = parse_fanout(inferencer_args.get("num_neighbors", "[10, 10]")) + sampler_options = parse_sampler_options(inferencer_args) # While the ideal value for `sampling_workers_per_inference_process` has been identified to be # between `2` and `4`, this may need some tuning depending on the pipeline. We default this @@ -516,6 +521,14 @@ def _run_example_inference( log_every_n_batch = int(inferencer_args.get("log_every_n_batch", "50")) + logger.info( + f"Got inference args local_world_size={local_world_size}, " + f"num_neighbors={num_neighbors}, sampler_options={sampler_options}, " + f"sampling_workers_per_inference_process={sampling_workers_per_inference_process}, " + f"sampling_worker_shared_channel_size={sampling_worker_shared_channel_size}, " + f"log_every_n_batch={log_every_n_batch}" + ) + # When using mp.spawn with `nprocs`, the first argument is implicitly set to be the process number on the current machine. inference_args = InferenceProcessArgs( local_world_size=local_world_size, @@ -528,6 +541,7 @@ def _run_example_inference( edge_feature_dim=edge_feature_dim, inference_batch_size=inference_batch_size, num_neighbors=num_neighbors, + sampler_options=sampler_options, sampling_workers_per_inference_process=sampling_workers_per_inference_process, sampling_worker_shared_channel_size=sampling_worker_shared_channel_size, log_every_n_batch=log_every_n_batch, diff --git a/examples/link_prediction/graph_store/homogeneous_training.py b/examples/link_prediction/graph_store/homogeneous_training.py index 04340f99a..3626f8566 100644 --- a/examples/link_prediction/graph_store/homogeneous_training.py +++ b/examples/link_prediction/graph_store/homogeneous_training.py @@ -143,6 +143,7 @@ shutdown_compute_process, ) from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset +from gigl.distributed.sampler_options import SamplerOptions from gigl.distributed.utils import get_available_device, get_graph_store_info from gigl.env.distributed import GraphStoreInfo from gigl.nn import LinkPredictionGNN, RetrievalLoss @@ -158,7 +159,7 @@ from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper from gigl.src.common.utils.model import load_state_dict_from_uri, save_state_dict from gigl.utils.iterator import InfiniteIterator -from gigl.utils.sampling import parse_fanout +from gigl.utils.sampling import parse_fanout, parse_sampler_options logger = Logger() @@ -191,6 +192,7 @@ def _setup_dataloaders( split: Literal["train", "val", "test"], cluster_info: GraphStoreInfo, num_neighbors: list[int] | dict[EdgeType, list[int]], + sampler_options: Optional[SamplerOptions], sampling_workers_per_process: int, main_batch_size: int, random_batch_size: int, @@ -205,6 +207,7 @@ def _setup_dataloaders( split (Literal["train", "val", "test"]): The current split which we are loading data for. cluster_info (GraphStoreInfo): Cluster topology info for graph store mode. num_neighbors: Fanout for subgraph sampling. + sampler_options (Optional[SamplerOptions]): Sampler variant. None uses k-hop sampling. sampling_workers_per_process (int): Number of sampling workers per training/testing process. main_batch_size (int): Batch size for main dataloader with query and labeled nodes. random_batch_size (int): Batch size for random negative dataloader. @@ -240,6 +243,7 @@ def _setup_dataloaders( channel_size=sampling_worker_shared_channel_size, process_start_gap_seconds=process_start_gap_seconds, shuffle=shuffle, + sampler_options=sampler_options, ) logger.info(f"---Rank {rank} finished setting up main loader for split={split}") @@ -266,6 +270,7 @@ def _setup_dataloaders( channel_size=sampling_worker_shared_channel_size, process_start_gap_seconds=process_start_gap_seconds, shuffle=shuffle, + sampler_options=sampler_options, ) logger.info( @@ -375,6 +380,7 @@ class TrainingProcessArgs: sampling_workers_per_process (int): Number of sampling workers per training/testing process. sampling_worker_shared_channel_size (str): Shared-memory buffer size for the channel during sampling. process_start_gap_seconds (int): Time to sleep between dataloader initializations. + sampler_options (Optional[SamplerOptions]): Sampler variant. None uses k-hop sampling. main_batch_size (int): Batch size for main dataloader. random_batch_size (int): Batch size for random negative dataloader. learning_rate (float): Learning rate for the optimizer. @@ -400,6 +406,7 @@ class TrainingProcessArgs: # Sampling config num_neighbors: list[int] | dict[EdgeType, list[int]] + sampler_options: Optional[SamplerOptions] sampling_workers_per_process: int sampling_worker_shared_channel_size: str process_start_gap_seconds: int @@ -463,6 +470,7 @@ def _training_process( split="train", cluster_info=args.cluster_info, num_neighbors=args.num_neighbors, + sampler_options=args.sampler_options, sampling_workers_per_process=args.sampling_workers_per_process, main_batch_size=args.main_batch_size, random_batch_size=args.random_batch_size, @@ -481,6 +489,7 @@ def _training_process( split="val", cluster_info=args.cluster_info, num_neighbors=args.num_neighbors, + sampler_options=args.sampler_options, sampling_workers_per_process=args.sampling_workers_per_process, main_batch_size=args.main_batch_size, random_batch_size=args.random_batch_size, @@ -637,6 +646,7 @@ def _training_process( split="test", cluster_info=args.cluster_info, num_neighbors=args.num_neighbors, + sampler_options=args.sampler_options, sampling_workers_per_process=args.sampling_workers_per_process, main_batch_size=args.main_batch_size, random_batch_size=args.random_batch_size, @@ -853,6 +863,7 @@ def _run_example_training( fanout = trainer_args.get("num_neighbors", "[10, 10]") num_neighbors = parse_fanout(fanout) + sampler_options = parse_sampler_options(trainer_args) sampling_workers_per_process: int = int( trainer_args.get("sampling_workers_per_process", "4") @@ -880,6 +891,7 @@ def _run_example_training( logger.info( f"Got training args local_world_size={local_world_size}, \ num_neighbors={num_neighbors}, \ + sampler_options={sampler_options}, \ sampling_workers_per_process={sampling_workers_per_process}, \ main_batch_size={main_batch_size}, \ random_batch_size={random_batch_size}, \ @@ -931,6 +943,7 @@ def _run_example_training( node_feature_dim=node_feature_dim, edge_feature_dim=edge_feature_dim, num_neighbors=num_neighbors, + sampler_options=sampler_options, sampling_workers_per_process=sampling_workers_per_process, sampling_worker_shared_channel_size=sampling_worker_shared_channel_size, process_start_gap_seconds=process_start_gap_seconds, diff --git a/gigl/utils/sampling.py b/gigl/utils/sampling.py index 5d0ed6a44..e2c6996e5 100644 --- a/gigl/utils/sampling.py +++ b/gigl/utils/sampling.py @@ -1,10 +1,12 @@ import ast +from collections.abc import Mapping from dataclasses import dataclass from typing import Any, Optional, Union import torch from gigl.common.logger import Logger +from gigl.distributed.sampler_options import PPRSamplerOptions, SamplerOptions from gigl.src.common.types.graph_data import EdgeType, NodeType logger = Logger() @@ -88,6 +90,45 @@ def parse_fanout(fanout_str: str) -> Union[list[int], dict[EdgeType, list[int]]] ) +def _parse_optional_int(value: Optional[str]) -> Optional[int]: + if value is None: + return None + normalized = value.strip().lower() + if normalized in {"", "none", "null"}: + return None + return int(value) + + +def parse_sampler_options(args: Mapping[str, str]) -> Optional[SamplerOptions]: + sampler_type = args.get("sampler_type", "khop").strip().lower().replace("-", "_") + if sampler_type == "": + sampler_type = "khop" + + if sampler_type in {"khop", "k_hop", "neighbor", "neighbor_sampler"}: + return None + + if sampler_type != "ppr": + raise ValueError( + f"Unsupported sampler_type={sampler_type}. Expected one of: khop, ppr." + ) + + max_ppr_nodes = args.get("ppr_max_nodes") + if max_ppr_nodes is None: + max_ppr_nodes = args.get("ppr_max_ppr_nodes", "50") + + num_neighbors_per_hop = args.get("ppr_neighbors_per_hop") + if num_neighbors_per_hop is None: + num_neighbors_per_hop = args.get("ppr_num_neighbors_per_hop", "1000") + + return PPRSamplerOptions( + alpha=float(args.get("ppr_alpha", "0.5")), + eps=float(args.get("ppr_eps", "0.0001")), + max_ppr_nodes=int(max_ppr_nodes), + num_neighbors_per_hop=int(num_neighbors_per_hop), + max_fetch_iterations=_parse_optional_int(args.get("ppr_max_fetch_iterations")), + ) + + @dataclass(frozen=True) class ABLPInputNodes: """Represents ABLP (Anchor Based Link Prediction) input for a single storage server. diff --git a/tests/e2e_tests/e2e_tests.yaml b/tests/e2e_tests/e2e_tests.yaml index 61fc4f311..6d09d8213 100644 --- a/tests/e2e_tests/e2e_tests.yaml +++ b/tests/e2e_tests/e2e_tests.yaml @@ -22,6 +22,9 @@ tests: hom_cora_sup_gs_test: task_config_uri: "examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_task_config.yaml" resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_GRAPH_STORE_RESOURCE_CONFIG,deployment/configs/e2e_glt_gs_resource_config.yaml}" + hom_cora_sup_gs_ppr_test: + task_config_uri: "examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml" + resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_GRAPH_STORE_RESOURCE_CONFIG,deployment/configs/e2e_glt_gs_resource_config.yaml}" het_dblp_sup_gs_test: task_config_uri: "examples/link_prediction/graph_store/configs/e2e_het_dblp_sup_gs_task_config.yaml" resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_GRAPH_STORE_RESOURCE_CONFIG,deployment/configs/e2e_glt_gs_resource_config.yaml}" From 845704b89e39e3aa1b5028909b43c2fc3e28f1dc Mon Sep 17 00:00:00 2001 From: mkolodner Date: Thu, 28 May 2026 20:47:43 +0000 Subject: [PATCH 08/32] Update --- .../configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml b/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml index 1cff49a4c..c9e35eeef 100644 --- a/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml +++ b/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml @@ -23,13 +23,12 @@ trainerConfig: ppr_max_nodes: "20" ppr_neighbors_per_hop: "100" ppr_max_fetch_iterations: "2" - local_world_size: "1" sampling_workers_per_process: "1" sampling_worker_shared_channel_size: "512MB" main_batch_size: "8" random_batch_size: "8" - num_max_train_batches: "2" - num_val_batches: "2" + num_max_train_batches: "4" + num_val_batches: "4" val_every_n_batch: "1" command: python -m examples.link_prediction.graph_store.homogeneous_training graphStoreStorageConfig: @@ -55,7 +54,6 @@ inferencerConfig: ppr_max_nodes: "20" ppr_neighbors_per_hop: "100" ppr_max_fetch_iterations: "2" - local_world_size: "1" sampling_workers_per_inference_process: "1" sampling_worker_shared_channel_size: "512MB" inferenceBatchSize: 256 From ebbc318101702e945edc41357f4e3e4601a6a74b Mon Sep 17 00:00:00 2001 From: mkolodner Date: Thu, 28 May 2026 21:46:31 +0000 Subject: [PATCH 09/32] Fixes --- .../e2e_hom_cora_sup_gs_ppr_task_config.yaml | 2 - .../graph_store/homogeneous_inference.py | 40 ++++++++----------- .../graph_store/homogeneous_training.py | 14 ++++--- 3 files changed, 25 insertions(+), 31 deletions(-) diff --git a/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml b/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml index c9e35eeef..878557cac 100644 --- a/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml +++ b/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml @@ -24,7 +24,6 @@ trainerConfig: ppr_neighbors_per_hop: "100" ppr_max_fetch_iterations: "2" sampling_workers_per_process: "1" - sampling_worker_shared_channel_size: "512MB" main_batch_size: "8" random_batch_size: "8" num_max_train_batches: "4" @@ -55,7 +54,6 @@ inferencerConfig: ppr_neighbors_per_hop: "100" ppr_max_fetch_iterations: "2" sampling_workers_per_inference_process: "1" - sampling_worker_shared_channel_size: "512MB" inferenceBatchSize: 256 command: python -m examples.link_prediction.graph_store.homogeneous_inference graphStoreStorageConfig: diff --git a/examples/link_prediction/graph_store/homogeneous_inference.py b/examples/link_prediction/graph_store/homogeneous_inference.py index eac5c519e..5faa84b72 100644 --- a/examples/link_prediction/graph_store/homogeneous_inference.py +++ b/examples/link_prediction/graph_store/homogeneous_inference.py @@ -115,12 +115,6 @@ logger = Logger() -# Default number of inference processes per machine incase one isnt provided in inference args -# i.e. `local_world_size` is not provided, and we can't infer automatically. -# If there are GPUs attached to the machine, we automatically infer to setting -# LOCAL_WORLD_SIZE == # of gpus on the machine. -DEFAULT_CPU_BASED_LOCAL_WORLD_SIZE = 4 - @dataclass(frozen=True) class InferenceProcessArgs: @@ -459,25 +453,23 @@ def _run_example_inference( if arg_local_world_size is not None: local_world_size = int(arg_local_world_size) logger.info(f"Using local_world_size from inferencer_args: {local_world_size}") - if torch.cuda.is_available() and local_world_size != torch.cuda.device_count(): - logger.warning( - f"local_world_size {local_world_size} does not match the number of GPUs {torch.cuda.device_count()}. " - "This may lead to unexpected failures with NCCL communication incase GPUs are being used for " - + "training/inference. Consider setting local_world_size to the number of GPUs." - ) else: - if torch.cuda.is_available() and torch.cuda.device_count() > 0: - # If GPUs are available, we set the local_world_size to the number of GPUs - local_world_size = torch.cuda.device_count() - logger.info( - f"Detected {local_world_size} GPUs. Thus, setting local_world_size to {local_world_size}" - ) - else: - # If no GPUs are available, we set the local_world_size to the number of inference processes per machine - logger.info( - f"No GPUs detected. Thus, setting local_world_size to `{DEFAULT_CPU_BASED_LOCAL_WORLD_SIZE}`" - ) - local_world_size = DEFAULT_CPU_BASED_LOCAL_WORLD_SIZE + local_world_size = cluster_info.num_processes_per_compute + logger.info( + f"Using local_world_size from cluster_info.num_processes_per_compute: {local_world_size}" + ) + if local_world_size != cluster_info.num_processes_per_compute: + raise ValueError( + f"Graph Store local_world_size={local_world_size} must match " + f"cluster_info.num_processes_per_compute=" + f"{cluster_info.num_processes_per_compute}" + ) + if torch.cuda.is_available() and local_world_size != torch.cuda.device_count(): + logger.warning( + f"local_world_size {local_world_size} does not match the number of GPUs {torch.cuda.device_count()}. " + "This may lead to unexpected failures with NCCL communication incase GPUs are being used for " + + "training/inference. Consider setting local_world_size to the number of GPUs." + ) if cluster_info.compute_node_rank == 0: gcs_utils = GcsUtils() diff --git a/examples/link_prediction/graph_store/homogeneous_training.py b/examples/link_prediction/graph_store/homogeneous_training.py index 3626f8566..c7ae356cc 100644 --- a/examples/link_prediction/graph_store/homogeneous_training.py +++ b/examples/link_prediction/graph_store/homogeneous_training.py @@ -847,13 +847,17 @@ def _run_example_training( # Training Hyperparameters trainer_args = dict(gbml_config_pb_wrapper.trainer_config.trainer_args) - if torch.cuda.is_available(): - default_local_world_size = torch.cuda.device_count() - else: - default_local_world_size = 2 local_world_size = int( - trainer_args.get("local_world_size", str(default_local_world_size)) + trainer_args.get( + "local_world_size", str(cluster_info.num_processes_per_compute) + ) ) + if local_world_size != cluster_info.num_processes_per_compute: + raise ValueError( + f"Graph Store local_world_size={local_world_size} must match " + f"cluster_info.num_processes_per_compute=" + f"{cluster_info.num_processes_per_compute}" + ) if torch.cuda.is_available(): if local_world_size > torch.cuda.device_count(): From 65eac992092e5dea40bba0f561a3e181e3300233 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Thu, 28 May 2026 22:41:37 +0000 Subject: [PATCH 10/32] Fix PPR graph-store sampling worker capacity --- .../configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml b/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml index 878557cac..46c508819 100644 --- a/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml +++ b/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml @@ -23,7 +23,7 @@ trainerConfig: ppr_max_nodes: "20" ppr_neighbors_per_hop: "100" ppr_max_fetch_iterations: "2" - sampling_workers_per_process: "1" + sampling_workers_per_process: "2" main_batch_size: "8" random_batch_size: "8" num_max_train_batches: "4" @@ -53,7 +53,7 @@ inferencerConfig: ppr_max_nodes: "20" ppr_neighbors_per_hop: "100" ppr_max_fetch_iterations: "2" - sampling_workers_per_inference_process: "1" + sampling_workers_per_inference_process: "2" inferenceBatchSize: 256 command: python -m examples.link_prediction.graph_store.homogeneous_inference graphStoreStorageConfig: From 97bd538659feed9ec3ca345616e09c95e991795f Mon Sep 17 00:00:00 2001 From: mkolodner Date: Fri, 29 May 2026 00:23:24 +0000 Subject: [PATCH 11/32] Fix --- gigl/distributed/dist_ppr_sampler.py | 8 +++ .../unit/distributed/dist_ppr_sampler_test.py | 54 +++++++++++++++++++ 2 files changed, 62 insertions(+) diff --git a/gigl/distributed/dist_ppr_sampler.py b/gigl/distributed/dist_ppr_sampler.py index 32eb7531c..2afa84a23 100644 --- a/gigl/distributed/dist_ppr_sampler.py +++ b/gigl/distributed/dist_ppr_sampler.py @@ -576,6 +576,14 @@ async def _sample_from_nodes( ) else: + if isinstance(nodes_to_sample, dict): + node_types = set(nodes_to_sample.keys()) + if node_types != {DEFAULT_HOMOGENEOUS_NODE_TYPE}: + raise ValueError( + f"Expected only {DEFAULT_HOMOGENEOUS_NODE_TYPE} for homogeneous PPR sampling, " + f"received node types: {node_types}" + ) + nodes_to_sample = nodes_to_sample[DEFAULT_HOMOGENEOUS_NODE_TYPE] assert isinstance(nodes_to_sample, torch.Tensor) # Register seeds; local indices 0..N-1 are assigned internally. diff --git a/tests/unit/distributed/dist_ppr_sampler_test.py b/tests/unit/distributed/dist_ppr_sampler_test.py index 400ce1107..e24dd0470 100644 --- a/tests/unit/distributed/dist_ppr_sampler_test.py +++ b/tests/unit/distributed/dist_ppr_sampler_test.py @@ -41,6 +41,10 @@ from gigl.distributed.dist_ablp_neighborloader import DistABLPLoader from gigl.distributed.distributed_neighborloader import DistNeighborLoader from gigl.distributed.sampler_options import PPRSamplerOptions +from gigl.types.graph import ( + DEFAULT_HOMOGENEOUS_EDGE_TYPE, + DEFAULT_HOMOGENEOUS_NODE_TYPE, +) from tests.test_assets.distributed.test_dataset import ( STORY, STORY_TO_USER, @@ -589,6 +593,52 @@ def _run_ppr_ablp_loader_correctness_check( shutdown_rpc() +def _run_ppr_labeled_homogeneous_ablp_loader_check(_: int) -> None: + """Verify PPR works for labeled homogeneous DistABLPLoader inputs.""" + create_test_process_group() + + dataset = create_heterogeneous_dataset_for_ablp( + positive_labels={0: [1], 1: [2], 2: [0]}, + negative_labels={0: [2], 1: [0], 2: [1]}, + train_node_ids=[0, 1], + val_node_ids=[2], + test_node_ids=[], + edge_indices={DEFAULT_HOMOGENEOUS_EDGE_TYPE: _TEST_EDGE_INDEX}, + src_node_type=DEFAULT_HOMOGENEOUS_NODE_TYPE, + dst_node_type=DEFAULT_HOMOGENEOUS_NODE_TYPE, + supervision_edge_type=DEFAULT_HOMOGENEOUS_EDGE_TYPE, + edge_dir="out", + ) + + train_node_ids = dataset.train_node_ids + assert isinstance(train_node_ids, dict) + + loader = DistABLPLoader( + dataset=dataset, + num_neighbors=[], + input_nodes=train_node_ids[DEFAULT_HOMOGENEOUS_NODE_TYPE], + sampler_options=PPRSamplerOptions( + alpha=_TEST_ALPHA, + eps=_TEST_EPS, + max_ppr_nodes=_TEST_MAX_PPR_NODES, + ), + pin_memory_device=torch.device("cpu"), + batch_size=1, + ) + + datum = next(iter(loader)) + assert isinstance(datum, Data) + assert hasattr(datum, "edge_index"), "Missing PPR edge_index on Data" + assert hasattr(datum, "edge_attr"), "Missing PPR edge_attr on Data" + assert hasattr(datum, "y_positive"), "Missing y_positive on Data" + assert hasattr(datum, "y_negative"), "Missing y_negative on Data" + assert datum.edge_index.dim() == 2 + assert datum.edge_index.size(0) == 2 + assert datum.edge_index.size(1) == datum.edge_attr.numel() + + shutdown_rpc() + + # --------------------------------------------------------------------------- # Bug regression runners # --------------------------------------------------------------------------- @@ -758,6 +808,10 @@ def test_ppr_sampler_ablp_ignores_label_edges_for_anchor_ppr(self) -> None: """Verify ABLP label edges are excluded from anchor-seed PPR walks.""" mp.spawn(fn=_run_ppr_ablp_label_edges_do_not_affect_anchor_ppr, args=()) + def test_ppr_sampler_homogeneous_ablp(self) -> None: + """Verify PPR handles homogeneous ABLP seed dictionaries.""" + mp.spawn(fn=_run_ppr_labeled_homogeneous_ablp_loader_check, args=()) + if __name__ == "__main__": absltest.main() From 92c9f515a43c0dc6668502914372536c1e2dbb99 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Fri, 29 May 2026 00:37:41 +0000 Subject: [PATCH 12/32] more fixes --- gigl/distributed/base_sampler.py | 15 ++++++++++----- gigl/distributed/dist_ppr_sampler.py | 11 +++++++++-- gigl/distributed/utils/neighborloader.py | 12 ++++++++++++ 3 files changed, 31 insertions(+), 7 deletions(-) diff --git a/gigl/distributed/base_sampler.py b/gigl/distributed/base_sampler.py index 986ba5d58..e8e6f9e77 100644 --- a/gigl/distributed/base_sampler.py +++ b/gigl/distributed/base_sampler.py @@ -1,3 +1,4 @@ +import logging from collections import defaultdict from dataclasses import dataclass from typing import Optional, Union @@ -213,11 +214,15 @@ async def _send_adapter( Copied from ``graphlearn_torch.distributed.DistNeighborSampler._send_adapter`` (GLT 0.2.4) with the single change of ``_colloate_fn`` → ``_collate_fn``. """ - sampler_output = await async_func(*args, **kwargs) - res = await self._collate_fn(sampler_output) - if self.channel is None: - return res - self.channel.send(res) + try: + sampler_output = await async_func(*args, **kwargs) + res = await self._collate_fn(sampler_output) + if self.channel is None: + return res + self.channel.send(res) + except Exception: + logging.exception("sampler task failed") + raise return None async def _collate_fn( diff --git a/gigl/distributed/dist_ppr_sampler.py b/gigl/distributed/dist_ppr_sampler.py index 2afa84a23..3223fa0f3 100644 --- a/gigl/distributed/dist_ppr_sampler.py +++ b/gigl/distributed/dist_ppr_sampler.py @@ -234,8 +234,15 @@ async def _batch_fetch_neighbors( self._sample_one_hop( srcs=nodes_by_etype_id[eid].to(device), num_nbr=self._num_neighbors_per_hop, - # _sample_one_hop expects None for homogeneous graphs, not the PPR sentinel. - etype=None if etype == _PPR_HOMOGENEOUS_EDGE_TYPE else etype, + # _sample_one_hop expects None only for true homogeneous graphs. + # Labeled homogeneous ABLP graphs are hetero-backed because label + # edges are represented as separate edge types, so they still need + # the explicit default edge type here. + etype=( + None + if self._is_homogeneous and etype == _PPR_HOMOGENEOUS_EDGE_TYPE + else etype + ), ) ) outputs: list[NeighborOutput] = await asyncio.gather(*sample_tasks) diff --git a/gigl/distributed/utils/neighborloader.py b/gigl/distributed/utils/neighborloader.py index b91b411e3..570fca93b 100644 --- a/gigl/distributed/utils/neighborloader.py +++ b/gigl/distributed/utils/neighborloader.py @@ -357,6 +357,18 @@ def attach_ppr_outputs( f"PPR edge index and weight edge types must match, " f"got {set(ppr_edge_indices.keys())} vs {set(ppr_weights.keys())}" ) + if isinstance(data, Data): + if len(ppr_edge_indices) > 1: + raise ValueError( + "Expected at most one PPR edge type for homogeneous Data output, " + f"got {set(ppr_edge_indices.keys())}" + ) + if ppr_edge_indices: + edge_type = next(iter(ppr_edge_indices)) + data.edge_index = ppr_edge_indices[edge_type] + data.edge_attr = ppr_weights[edge_type] + return + for edge_type, edge_index in ppr_edge_indices.items(): data[edge_type].edge_index = edge_index data[edge_type].edge_attr = ppr_weights[edge_type] From 7e31417af7e3ccc508bd19bf5b80cb707ad8139f Mon Sep 17 00:00:00 2001 From: mkolodner Date: Fri, 29 May 2026 00:49:58 +0000 Subject: [PATCH 13/32] change back --- gigl/distributed/dist_ppr_sampler.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/gigl/distributed/dist_ppr_sampler.py b/gigl/distributed/dist_ppr_sampler.py index 3223fa0f3..80049f305 100644 --- a/gigl/distributed/dist_ppr_sampler.py +++ b/gigl/distributed/dist_ppr_sampler.py @@ -583,25 +583,32 @@ async def _sample_from_nodes( ) else: - if isinstance(nodes_to_sample, dict): + if isinstance(nodes_to_sample, torch.Tensor): + homogeneous_nodes_to_sample = nodes_to_sample + elif isinstance(nodes_to_sample, dict): node_types = set(nodes_to_sample.keys()) if node_types != {DEFAULT_HOMOGENEOUS_NODE_TYPE}: raise ValueError( f"Expected only {DEFAULT_HOMOGENEOUS_NODE_TYPE} for homogeneous PPR sampling, " f"received node types: {node_types}" ) - nodes_to_sample = nodes_to_sample[DEFAULT_HOMOGENEOUS_NODE_TYPE] - assert isinstance(nodes_to_sample, torch.Tensor) + homogeneous_nodes_to_sample = nodes_to_sample[ + DEFAULT_HOMOGENEOUS_NODE_TYPE + ] + else: + raise TypeError( + f"Expected Tensor or node-type mapping for homogeneous PPR sampling, got {type(nodes_to_sample)}" + ) # Register seeds; local indices 0..N-1 are assigned internally. # srcs holds their global IDs (same values as nodes_to_sample). - srcs = inducer.init_node(nodes_to_sample) + srcs = inducer.init_node(homogeneous_nodes_to_sample) ( homo_flat_ids, homo_flat_weights, homo_valid_counts, - ) = await self._compute_ppr_scores(nodes_to_sample, None) + ) = await self._compute_ppr_scores(homogeneous_nodes_to_sample, None) assert isinstance(homo_flat_ids, torch.Tensor) assert isinstance(homo_flat_weights, torch.Tensor) assert isinstance(homo_valid_counts, torch.Tensor) From d9d2086c29b40bf704e185aca98321bbec583c5c Mon Sep 17 00:00:00 2001 From: mkolodner Date: Fri, 29 May 2026 16:48:51 +0000 Subject: [PATCH 14/32] Avoid cast for heterogeneous inference node ids --- .../link_prediction/heterogeneous_inference.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/examples/link_prediction/heterogeneous_inference.py b/examples/link_prediction/heterogeneous_inference.py index c2f047926..29c9048e7 100644 --- a/examples/link_prediction/heterogeneous_inference.py +++ b/examples/link_prediction/heterogeneous_inference.py @@ -23,7 +23,7 @@ import gc import time from dataclasses import dataclass -from typing import Optional, Union, cast +from typing import Optional, Union import torch import torch.distributed @@ -152,13 +152,17 @@ def _inference_process( node_type_to_input_node_ids: Optional[ Union[torch.Tensor, dict[NodeType, torch.Tensor]] ] = args.dataset.node_ids - assert isinstance(node_type_to_input_node_ids, dict), ( - f"Node IDs must be a dictionary for heterogeneous inference, got {type(node_type_to_input_node_ids)}" - ) - node_type_to_input_node_ids = cast( - dict[NodeType, torch.Tensor], node_type_to_input_node_ids - ) + if node_type_to_input_node_ids is None or isinstance( + node_type_to_input_node_ids, torch.Tensor + ): + raise TypeError( + f"Node IDs must be a dictionary for heterogeneous inference, got {type(node_type_to_input_node_ids)}" + ) input_node_ids: torch.Tensor = node_type_to_input_node_ids[args.inference_node_type] + assert isinstance(input_node_ids, torch.Tensor), ( + f"Expected Tensor node IDs for node type {args.inference_node_type}, " + f"got {type(input_node_ids)}" + ) data_loader = gigl.distributed.DistNeighborLoader( dataset=args.dataset, From fd1e9ae537382ef4f0f3cfffa3935fe3ac7ae6e0 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Fri, 29 May 2026 16:55:57 +0000 Subject: [PATCH 15/32] Trim branch to PPR sampler fixes --- Makefile | 8 -- .../e2e_hom_cora_sup_gs_ppr_task_config.yaml | 75 ------------------- .../graph_store/homogeneous_inference.py | 58 +++++++------- .../graph_store/homogeneous_training.py | 29 ++----- .../heterogeneous_inference.py | 13 +--- gigl/common/metrics/decorators.py | 4 +- .../node_classification_modeling_task_spec.py | 2 +- gigl/utils/sampling.py | 41 ---------- tests/e2e_tests/e2e_tests.yaml | 3 - .../dataset_input_metadata_translator_test.py | 36 ++++----- .../unit/distributed/dist_ppr_sampler_test.py | 22 +++++- 11 files changed, 70 insertions(+), 221 deletions(-) delete mode 100644 examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml diff --git a/Makefile b/Makefile index e378fdd76..f80eb2952 100644 --- a/Makefile +++ b/Makefile @@ -270,14 +270,6 @@ run_hom_cora_sup_gs_e2e_test: --test_spec_uri="tests/e2e_tests/e2e_tests.yaml" \ --test_names="hom_cora_sup_gs_test" -run_hom_cora_sup_gs_ppr_e2e_test: compiled_pipeline_path:=${GIGL_E2E_TEST_COMPILED_PIPELINE_PATH} -run_hom_cora_sup_gs_ppr_e2e_test: compile_gigl_kubeflow_pipeline -run_hom_cora_sup_gs_ppr_e2e_test: - uv run python tests/e2e_tests/e2e_test.py \ - --compiled_pipeline_path=$(compiled_pipeline_path) \ - --test_spec_uri="tests/e2e_tests/e2e_tests.yaml" \ - --test_names="hom_cora_sup_gs_ppr_test" - run_het_dblp_sup_gs_e2e_test: compiled_pipeline_path:=${GIGL_E2E_TEST_COMPILED_PIPELINE_PATH} run_het_dblp_sup_gs_e2e_test: compile_gigl_kubeflow_pipeline run_het_dblp_sup_gs_e2e_test: diff --git a/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml b/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml deleted file mode 100644 index 46c508819..000000000 --- a/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml +++ /dev/null @@ -1,75 +0,0 @@ -# This config runs homogeneous CORA supervised training and inference in Graph Store mode -# with PPR sampling. It intentionally reuses the standard graph-store training/inference -# entrypoints, changing only the sampler args and keeping the loop short for E2E coverage. -graphMetadata: - edgeTypes: - - dstNodeType: paper - relation: cites - srcNodeType: paper - nodeTypes: - - paper -datasetConfig: - dataPreprocessorConfig: - dataPreprocessorConfigClsPath: gigl.src.mocking.mocking_assets.passthrough_preprocessor_config_for_mocked_assets.PassthroughPreprocessorConfigForMockedAssets - dataPreprocessorArgs: - mocked_dataset_name: 'cora_homogeneous_node_anchor_edge_features_user_defined_labels' -trainerConfig: - trainerArgs: - log_every_n_batch: "1" - num_neighbors: "[10, 10]" - sampler_type: "ppr" - ppr_alpha: "0.5" - ppr_eps: "0.0001" - ppr_max_nodes: "20" - ppr_neighbors_per_hop: "100" - ppr_max_fetch_iterations: "2" - sampling_workers_per_process: "2" - main_batch_size: "8" - random_batch_size: "8" - num_max_train_batches: "4" - num_val_batches: "4" - val_every_n_batch: "1" - command: python -m examples.link_prediction.graph_store.homogeneous_training - graphStoreStorageConfig: - command: python -m examples.link_prediction.graph_store.storage_main - storageArgs: - sample_edge_direction: "in" - splitter_cls_path: "gigl.utils.data_splitters.DistNodeAnchorLinkSplitter" - splitter_kwargs: >- - { - "sampling_direction": "in", - "should_convert_labels_to_edges": True, - "num_val": 0.25, - "num_test": 0.25 - } - num_server_sessions: "1" -inferencerConfig: - inferencerArgs: - log_every_n_batch: "1" - num_neighbors: "[10, 10]" - sampler_type: "ppr" - ppr_alpha: "0.5" - ppr_eps: "0.0001" - ppr_max_nodes: "20" - ppr_neighbors_per_hop: "100" - ppr_max_fetch_iterations: "2" - sampling_workers_per_inference_process: "2" - inferenceBatchSize: 256 - command: python -m examples.link_prediction.graph_store.homogeneous_inference - graphStoreStorageConfig: - command: python -m examples.link_prediction.graph_store.storage_main - storageArgs: - sample_edge_direction: "in" - num_server_sessions: "1" -sharedConfig: - shouldSkipInference: false - shouldSkipModelEvaluation: true -taskMetadata: - nodeAnchorBasedLinkPredictionTaskMetadata: - supervisionEdgeTypes: - - dstNodeType: paper - relation: cites - srcNodeType: paper -featureFlags: - should_run_glt_backend: 'True' - data_preprocessor_num_shards: '2' diff --git a/examples/link_prediction/graph_store/homogeneous_inference.py b/examples/link_prediction/graph_store/homogeneous_inference.py index 5faa84b72..34bc2672e 100644 --- a/examples/link_prediction/graph_store/homogeneous_inference.py +++ b/examples/link_prediction/graph_store/homogeneous_inference.py @@ -87,7 +87,7 @@ import sys import time from dataclasses import dataclass -from typing import Optional, Union +from typing import Union import torch import torch.multiprocessing as mp @@ -101,7 +101,6 @@ from gigl.common.utils.gcs import GcsUtils from gigl.distributed.graph_store.compute import init_compute_process from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset -from gigl.distributed.sampler_options import SamplerOptions from gigl.distributed.utils import get_graph_store_info from gigl.env.distributed import GraphStoreInfo from gigl.nn import LinkPredictionGNN @@ -111,10 +110,16 @@ from gigl.src.common.utils.bq import BqUtils from gigl.src.common.utils.model import load_state_dict_from_uri from gigl.src.inference.lib.assets import InferenceAssets -from gigl.utils.sampling import parse_fanout, parse_sampler_options +from gigl.utils.sampling import parse_fanout logger = Logger() +# Default number of inference processes per machine incase one isnt provided in inference args +# i.e. `local_world_size` is not provided, and we can't infer automatically. +# If there are GPUs attached to the machine, we automatically infer to setting +# LOCAL_WORLD_SIZE == # of gpus on the machine. +DEFAULT_CPU_BASED_LOCAL_WORLD_SIZE = 4 + @dataclass(frozen=True) class InferenceProcessArgs: @@ -138,7 +143,6 @@ class InferenceProcessArgs: inference_batch_size (int): Batch size to use for inference. num_neighbors (Union[list[int], dict[EdgeType, list[int]]]): Fanout for subgraph sampling, where the ith item corresponds to the number of items to sample for the ith hop. - sampler_options (Optional[SamplerOptions]): Sampler variant. None uses k-hop sampling. sampling_workers_per_inference_process (int): Number of sampling workers per inference process. sampling_worker_shared_channel_size (str): Shared-memory buffer size (bytes) allocated for @@ -165,7 +169,6 @@ class InferenceProcessArgs: # Inference configuration inference_batch_size: int num_neighbors: Union[list[int], dict[EdgeType, list[int]]] - sampler_options: Optional[SamplerOptions] sampling_workers_per_inference_process: int sampling_worker_shared_channel_size: str log_every_n_batch: int @@ -239,7 +242,6 @@ def _inference_process( # For large-scale settings, consider setting this field to 30-60 seconds to ensure dataloaders # don't compete for memory during initialization, causing OOM process_start_gap_seconds=0, - sampler_options=args.sampler_options, ) # Initialize a LinkPredictionGNN model and load parameters from # the saved model. @@ -453,23 +455,25 @@ def _run_example_inference( if arg_local_world_size is not None: local_world_size = int(arg_local_world_size) logger.info(f"Using local_world_size from inferencer_args: {local_world_size}") + if torch.cuda.is_available() and local_world_size != torch.cuda.device_count(): + logger.warning( + f"local_world_size {local_world_size} does not match the number of GPUs {torch.cuda.device_count()}. " + "This may lead to unexpected failures with NCCL communication incase GPUs are being used for " + + "training/inference. Consider setting local_world_size to the number of GPUs." + ) else: - local_world_size = cluster_info.num_processes_per_compute - logger.info( - f"Using local_world_size from cluster_info.num_processes_per_compute: {local_world_size}" - ) - if local_world_size != cluster_info.num_processes_per_compute: - raise ValueError( - f"Graph Store local_world_size={local_world_size} must match " - f"cluster_info.num_processes_per_compute=" - f"{cluster_info.num_processes_per_compute}" - ) - if torch.cuda.is_available() and local_world_size != torch.cuda.device_count(): - logger.warning( - f"local_world_size {local_world_size} does not match the number of GPUs {torch.cuda.device_count()}. " - "This may lead to unexpected failures with NCCL communication incase GPUs are being used for " - + "training/inference. Consider setting local_world_size to the number of GPUs." - ) + if torch.cuda.is_available() and torch.cuda.device_count() > 0: + # If GPUs are available, we set the local_world_size to the number of GPUs + local_world_size = torch.cuda.device_count() + logger.info( + f"Detected {local_world_size} GPUs. Thus, setting local_world_size to {local_world_size}" + ) + else: + # If no GPUs are available, we set the local_world_size to the number of inference processes per machine + logger.info( + f"No GPUs detected. Thus, setting local_world_size to `{DEFAULT_CPU_BASED_LOCAL_WORLD_SIZE}`" + ) + local_world_size = DEFAULT_CPU_BASED_LOCAL_WORLD_SIZE if cluster_info.compute_node_rank == 0: gcs_utils = GcsUtils() @@ -490,7 +494,6 @@ def _run_example_inference( # Parses the fanout as a string. For the homogeneous case, the fanouts should be specified # as a string of a list of integers, such as "[10, 10]". num_neighbors = parse_fanout(inferencer_args.get("num_neighbors", "[10, 10]")) - sampler_options = parse_sampler_options(inferencer_args) # While the ideal value for `sampling_workers_per_inference_process` has been identified to be # between `2` and `4`, this may need some tuning depending on the pipeline. We default this @@ -513,14 +516,6 @@ def _run_example_inference( log_every_n_batch = int(inferencer_args.get("log_every_n_batch", "50")) - logger.info( - f"Got inference args local_world_size={local_world_size}, " - f"num_neighbors={num_neighbors}, sampler_options={sampler_options}, " - f"sampling_workers_per_inference_process={sampling_workers_per_inference_process}, " - f"sampling_worker_shared_channel_size={sampling_worker_shared_channel_size}, " - f"log_every_n_batch={log_every_n_batch}" - ) - # When using mp.spawn with `nprocs`, the first argument is implicitly set to be the process number on the current machine. inference_args = InferenceProcessArgs( local_world_size=local_world_size, @@ -533,7 +528,6 @@ def _run_example_inference( edge_feature_dim=edge_feature_dim, inference_batch_size=inference_batch_size, num_neighbors=num_neighbors, - sampler_options=sampler_options, sampling_workers_per_inference_process=sampling_workers_per_inference_process, sampling_worker_shared_channel_size=sampling_worker_shared_channel_size, log_every_n_batch=log_every_n_batch, diff --git a/examples/link_prediction/graph_store/homogeneous_training.py b/examples/link_prediction/graph_store/homogeneous_training.py index c7ae356cc..04340f99a 100644 --- a/examples/link_prediction/graph_store/homogeneous_training.py +++ b/examples/link_prediction/graph_store/homogeneous_training.py @@ -143,7 +143,6 @@ shutdown_compute_process, ) from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset -from gigl.distributed.sampler_options import SamplerOptions from gigl.distributed.utils import get_available_device, get_graph_store_info from gigl.env.distributed import GraphStoreInfo from gigl.nn import LinkPredictionGNN, RetrievalLoss @@ -159,7 +158,7 @@ from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper from gigl.src.common.utils.model import load_state_dict_from_uri, save_state_dict from gigl.utils.iterator import InfiniteIterator -from gigl.utils.sampling import parse_fanout, parse_sampler_options +from gigl.utils.sampling import parse_fanout logger = Logger() @@ -192,7 +191,6 @@ def _setup_dataloaders( split: Literal["train", "val", "test"], cluster_info: GraphStoreInfo, num_neighbors: list[int] | dict[EdgeType, list[int]], - sampler_options: Optional[SamplerOptions], sampling_workers_per_process: int, main_batch_size: int, random_batch_size: int, @@ -207,7 +205,6 @@ def _setup_dataloaders( split (Literal["train", "val", "test"]): The current split which we are loading data for. cluster_info (GraphStoreInfo): Cluster topology info for graph store mode. num_neighbors: Fanout for subgraph sampling. - sampler_options (Optional[SamplerOptions]): Sampler variant. None uses k-hop sampling. sampling_workers_per_process (int): Number of sampling workers per training/testing process. main_batch_size (int): Batch size for main dataloader with query and labeled nodes. random_batch_size (int): Batch size for random negative dataloader. @@ -243,7 +240,6 @@ def _setup_dataloaders( channel_size=sampling_worker_shared_channel_size, process_start_gap_seconds=process_start_gap_seconds, shuffle=shuffle, - sampler_options=sampler_options, ) logger.info(f"---Rank {rank} finished setting up main loader for split={split}") @@ -270,7 +266,6 @@ def _setup_dataloaders( channel_size=sampling_worker_shared_channel_size, process_start_gap_seconds=process_start_gap_seconds, shuffle=shuffle, - sampler_options=sampler_options, ) logger.info( @@ -380,7 +375,6 @@ class TrainingProcessArgs: sampling_workers_per_process (int): Number of sampling workers per training/testing process. sampling_worker_shared_channel_size (str): Shared-memory buffer size for the channel during sampling. process_start_gap_seconds (int): Time to sleep between dataloader initializations. - sampler_options (Optional[SamplerOptions]): Sampler variant. None uses k-hop sampling. main_batch_size (int): Batch size for main dataloader. random_batch_size (int): Batch size for random negative dataloader. learning_rate (float): Learning rate for the optimizer. @@ -406,7 +400,6 @@ class TrainingProcessArgs: # Sampling config num_neighbors: list[int] | dict[EdgeType, list[int]] - sampler_options: Optional[SamplerOptions] sampling_workers_per_process: int sampling_worker_shared_channel_size: str process_start_gap_seconds: int @@ -470,7 +463,6 @@ def _training_process( split="train", cluster_info=args.cluster_info, num_neighbors=args.num_neighbors, - sampler_options=args.sampler_options, sampling_workers_per_process=args.sampling_workers_per_process, main_batch_size=args.main_batch_size, random_batch_size=args.random_batch_size, @@ -489,7 +481,6 @@ def _training_process( split="val", cluster_info=args.cluster_info, num_neighbors=args.num_neighbors, - sampler_options=args.sampler_options, sampling_workers_per_process=args.sampling_workers_per_process, main_batch_size=args.main_batch_size, random_batch_size=args.random_batch_size, @@ -646,7 +637,6 @@ def _training_process( split="test", cluster_info=args.cluster_info, num_neighbors=args.num_neighbors, - sampler_options=args.sampler_options, sampling_workers_per_process=args.sampling_workers_per_process, main_batch_size=args.main_batch_size, random_batch_size=args.random_batch_size, @@ -847,17 +837,13 @@ def _run_example_training( # Training Hyperparameters trainer_args = dict(gbml_config_pb_wrapper.trainer_config.trainer_args) + if torch.cuda.is_available(): + default_local_world_size = torch.cuda.device_count() + else: + default_local_world_size = 2 local_world_size = int( - trainer_args.get( - "local_world_size", str(cluster_info.num_processes_per_compute) - ) + trainer_args.get("local_world_size", str(default_local_world_size)) ) - if local_world_size != cluster_info.num_processes_per_compute: - raise ValueError( - f"Graph Store local_world_size={local_world_size} must match " - f"cluster_info.num_processes_per_compute=" - f"{cluster_info.num_processes_per_compute}" - ) if torch.cuda.is_available(): if local_world_size > torch.cuda.device_count(): @@ -867,7 +853,6 @@ def _run_example_training( fanout = trainer_args.get("num_neighbors", "[10, 10]") num_neighbors = parse_fanout(fanout) - sampler_options = parse_sampler_options(trainer_args) sampling_workers_per_process: int = int( trainer_args.get("sampling_workers_per_process", "4") @@ -895,7 +880,6 @@ def _run_example_training( logger.info( f"Got training args local_world_size={local_world_size}, \ num_neighbors={num_neighbors}, \ - sampler_options={sampler_options}, \ sampling_workers_per_process={sampling_workers_per_process}, \ main_batch_size={main_batch_size}, \ random_batch_size={random_batch_size}, \ @@ -947,7 +931,6 @@ def _run_example_training( node_feature_dim=node_feature_dim, edge_feature_dim=edge_feature_dim, num_neighbors=num_neighbors, - sampler_options=sampler_options, sampling_workers_per_process=sampling_workers_per_process, sampling_worker_shared_channel_size=sampling_worker_shared_channel_size, process_start_gap_seconds=process_start_gap_seconds, diff --git a/examples/link_prediction/heterogeneous_inference.py b/examples/link_prediction/heterogeneous_inference.py index 29c9048e7..b676044d7 100644 --- a/examples/link_prediction/heterogeneous_inference.py +++ b/examples/link_prediction/heterogeneous_inference.py @@ -152,17 +152,12 @@ def _inference_process( node_type_to_input_node_ids: Optional[ Union[torch.Tensor, dict[NodeType, torch.Tensor]] ] = args.dataset.node_ids - if node_type_to_input_node_ids is None or isinstance( + assert node_type_to_input_node_ids is not None and not isinstance( node_type_to_input_node_ids, torch.Tensor - ): - raise TypeError( - f"Node IDs must be a dictionary for heterogeneous inference, got {type(node_type_to_input_node_ids)}" - ) - input_node_ids: torch.Tensor = node_type_to_input_node_ids[args.inference_node_type] - assert isinstance(input_node_ids, torch.Tensor), ( - f"Expected Tensor node IDs for node type {args.inference_node_type}, " - f"got {type(input_node_ids)}" + ), ( + f"Node IDs must be a dictionary for heterogeneous inference, got {type(node_type_to_input_node_ids)}" ) + input_node_ids: torch.Tensor = node_type_to_input_node_ids[args.inference_node_type] data_loader = gigl.distributed.DistNeighborLoader( dataset=args.dataset, diff --git a/gigl/common/metrics/decorators.py b/gigl/common/metrics/decorators.py index 6f84b8737..d0561e61c 100644 --- a/gigl/common/metrics/decorators.py +++ b/gigl/common/metrics/decorators.py @@ -22,7 +22,6 @@ def __safely_flush_metrics( Callable[[], Optional[OpsMetricPublisher]] ], ) -> None: - metrics_instance = None if get_metrics_service_instance_fn is not None: metrics_instance = get_metrics_service_instance_fn() if metrics_instance is not None: @@ -46,9 +45,8 @@ def wrap(*args: Any, **kwargs: Any) -> Any: try: result = func(*args, **kwargs) except Exception as e: - func_name = getattr(func, "__name__", repr(func)) logger.info( - f"Exception raised, will flush metrics for: {func_name} and re-raise exception" + f"Exception raised, will flush metrics for: {getattr(func, '__name__')} and re-raise exception" ) logger.error(f"Exception: {e}") logger.error(traceback.format_exc()) diff --git a/gigl/src/common/modeling_task_specs/node_classification_modeling_task_spec.py b/gigl/src/common/modeling_task_specs/node_classification_modeling_task_spec.py index 66809b5a2..965f67915 100644 --- a/gigl/src/common/modeling_task_specs/node_classification_modeling_task_spec.py +++ b/gigl/src/common/modeling_task_specs/node_classification_modeling_task_spec.py @@ -203,7 +203,7 @@ def score( results: InferBatchResults = self.infer_batch(batch=batch, device=device) assert results.predictions is not None num_correct_in_batch = int( - torch.eq(results.predictions, root_node_labels).sum().item() + (results.predictions == root_node_labels).sum() ) # https://github.com/Snapchat/GiGL/issues/408 num_correct += num_correct_in_batch num_evaluated += len(batch.root_node_labels) diff --git a/gigl/utils/sampling.py b/gigl/utils/sampling.py index e2c6996e5..5d0ed6a44 100644 --- a/gigl/utils/sampling.py +++ b/gigl/utils/sampling.py @@ -1,12 +1,10 @@ import ast -from collections.abc import Mapping from dataclasses import dataclass from typing import Any, Optional, Union import torch from gigl.common.logger import Logger -from gigl.distributed.sampler_options import PPRSamplerOptions, SamplerOptions from gigl.src.common.types.graph_data import EdgeType, NodeType logger = Logger() @@ -90,45 +88,6 @@ def parse_fanout(fanout_str: str) -> Union[list[int], dict[EdgeType, list[int]]] ) -def _parse_optional_int(value: Optional[str]) -> Optional[int]: - if value is None: - return None - normalized = value.strip().lower() - if normalized in {"", "none", "null"}: - return None - return int(value) - - -def parse_sampler_options(args: Mapping[str, str]) -> Optional[SamplerOptions]: - sampler_type = args.get("sampler_type", "khop").strip().lower().replace("-", "_") - if sampler_type == "": - sampler_type = "khop" - - if sampler_type in {"khop", "k_hop", "neighbor", "neighbor_sampler"}: - return None - - if sampler_type != "ppr": - raise ValueError( - f"Unsupported sampler_type={sampler_type}. Expected one of: khop, ppr." - ) - - max_ppr_nodes = args.get("ppr_max_nodes") - if max_ppr_nodes is None: - max_ppr_nodes = args.get("ppr_max_ppr_nodes", "50") - - num_neighbors_per_hop = args.get("ppr_neighbors_per_hop") - if num_neighbors_per_hop is None: - num_neighbors_per_hop = args.get("ppr_num_neighbors_per_hop", "1000") - - return PPRSamplerOptions( - alpha=float(args.get("ppr_alpha", "0.5")), - eps=float(args.get("ppr_eps", "0.0001")), - max_ppr_nodes=int(max_ppr_nodes), - num_neighbors_per_hop=int(num_neighbors_per_hop), - max_fetch_iterations=_parse_optional_int(args.get("ppr_max_fetch_iterations")), - ) - - @dataclass(frozen=True) class ABLPInputNodes: """Represents ABLP (Anchor Based Link Prediction) input for a single storage server. diff --git a/tests/e2e_tests/e2e_tests.yaml b/tests/e2e_tests/e2e_tests.yaml index 6d09d8213..61fc4f311 100644 --- a/tests/e2e_tests/e2e_tests.yaml +++ b/tests/e2e_tests/e2e_tests.yaml @@ -22,9 +22,6 @@ tests: hom_cora_sup_gs_test: task_config_uri: "examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_task_config.yaml" resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_GRAPH_STORE_RESOURCE_CONFIG,deployment/configs/e2e_glt_gs_resource_config.yaml}" - hom_cora_sup_gs_ppr_test: - task_config_uri: "examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml" - resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_GRAPH_STORE_RESOURCE_CONFIG,deployment/configs/e2e_glt_gs_resource_config.yaml}" het_dblp_sup_gs_test: task_config_uri: "examples/link_prediction/graph_store/configs/e2e_het_dblp_sup_gs_task_config.yaml" resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_GRAPH_STORE_RESOURCE_CONFIG,deployment/configs/e2e_glt_gs_resource_config.yaml}" diff --git a/tests/unit/distributed/dataset_input_metadata_translator_test.py b/tests/unit/distributed/dataset_input_metadata_translator_test.py index 2899848f3..49156166f 100644 --- a/tests/unit/distributed/dataset_input_metadata_translator_test.py +++ b/tests/unit/distributed/dataset_input_metadata_translator_test.py @@ -113,18 +113,14 @@ def test_translator_correctness(self, _, mocked_dataset_info: MockedDatasetInfo) expected_entity_types=graph_metadata_pb_wrapper.node_types, ) - if isinstance(serialized_graph_metadata.node_entity_info, abc.Mapping): - serialized_node_info_iterable = cast( - list[SerializedTFRecordInfo], - list(serialized_graph_metadata.node_entity_info.values()), - ) + if isinstance( + serialized_graph_metadata.node_entity_info, SerializedTFRecordInfo + ): + serialized_node_info_iterable = [serialized_graph_metadata.node_entity_info] else: - serialized_node_info_iterable = [ - cast( - SerializedTFRecordInfo, - serialized_graph_metadata.node_entity_info, - ) - ] + serialized_node_info_iterable = list( + serialized_graph_metadata.node_entity_info.values() + ) self.assertEqual( len(graph_metadata_pb_wrapper.node_types), @@ -194,18 +190,14 @@ def test_translator_correctness(self, _, mocked_dataset_info: MockedDatasetInfo) expected_entity_types=graph_metadata_pb_wrapper.edge_types, ) - if isinstance(serialized_graph_metadata.edge_entity_info, abc.Mapping): - serialized_edge_info_iterable = cast( - list[SerializedTFRecordInfo], - list(serialized_graph_metadata.edge_entity_info.values()), - ) + if isinstance( + serialized_graph_metadata.edge_entity_info, SerializedTFRecordInfo + ): + serialized_edge_info_iterable = [serialized_graph_metadata.edge_entity_info] else: - serialized_edge_info_iterable = [ - cast( - SerializedTFRecordInfo, - serialized_graph_metadata.edge_entity_info, - ) - ] + serialized_edge_info_iterable = list( + serialized_graph_metadata.edge_entity_info.values() + ) self.assertEqual( len(graph_metadata_pb_wrapper.edge_types), diff --git a/tests/unit/distributed/dist_ppr_sampler_test.py b/tests/unit/distributed/dist_ppr_sampler_test.py index e24dd0470..4837f5ef4 100644 --- a/tests/unit/distributed/dist_ppr_sampler_test.py +++ b/tests/unit/distributed/dist_ppr_sampler_test.py @@ -28,7 +28,7 @@ import heapq from collections import defaultdict -from typing import Literal +from typing import Literal, TypeGuard import networkx as nx import torch @@ -95,6 +95,14 @@ _TEST_MAX_PPR_NODES = 5 +def _is_node_type_to_tensor_map( + value: object, +) -> TypeGuard[dict[str, torch.Tensor]]: + return isinstance(value, dict) and all( + isinstance(node_ids, torch.Tensor) for node_ids in value.values() + ) + + # --------------------------------------------------------------------------- # Reference PPR implementations (NetworkX-based) # --------------------------------------------------------------------------- @@ -508,12 +516,15 @@ def _run_ppr_ablp_loader_correctness_check( ) train_node_ids = dataset.train_node_ids - assert isinstance(train_node_ids, dict) + if not _is_node_type_to_tensor_map(train_node_ids): + raise TypeError( + f"Expected train_node_ids to be a dictionary, got {type(train_node_ids)}" + ) loader = DistABLPLoader( dataset=dataset, num_neighbors=[], # Unused by PPR sampler; required by interface - input_nodes=(USER, train_node_ids[USER]), # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. + input_nodes=(USER, train_node_ids[USER]), supervision_edge_type=USER_TO_STORY, sampler_options=PPRSamplerOptions( alpha=alpha, @@ -611,7 +622,10 @@ def _run_ppr_labeled_homogeneous_ablp_loader_check(_: int) -> None: ) train_node_ids = dataset.train_node_ids - assert isinstance(train_node_ids, dict) + if not _is_node_type_to_tensor_map(train_node_ids): + raise TypeError( + f"Expected train_node_ids to be a dictionary, got {type(train_node_ids)}" + ) loader = DistABLPLoader( dataset=dataset, From a49a650949c07fd87cdb33eca0f15a2a6ad10c56 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Fri, 29 May 2026 17:04:46 +0000 Subject: [PATCH 16/32] Add graph-store PPR E2E wiring --- Makefile | 8 ++ .../e2e_hom_cora_sup_gs_ppr_task_config.yaml | 75 +++++++++++++++++++ .../graph_store/homogeneous_inference.py | 58 +++++++------- .../graph_store/homogeneous_training.py | 29 +++++-- gigl/utils/sampling.py | 41 ++++++++++ tests/e2e_tests/e2e_tests.yaml | 3 + 6 files changed, 182 insertions(+), 32 deletions(-) create mode 100644 examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml diff --git a/Makefile b/Makefile index 93ab75ffc..dab742500 100644 --- a/Makefile +++ b/Makefile @@ -270,6 +270,14 @@ run_hom_cora_sup_gs_e2e_test: --test_spec_uri="tests/e2e_tests/e2e_tests.yaml" \ --test_names="hom_cora_sup_gs_test" +run_hom_cora_sup_gs_ppr_e2e_test: compiled_pipeline_path:=${GIGL_E2E_TEST_COMPILED_PIPELINE_PATH} +run_hom_cora_sup_gs_ppr_e2e_test: compile_gigl_kubeflow_pipeline +run_hom_cora_sup_gs_ppr_e2e_test: + uv run python tests/e2e_tests/e2e_test.py \ + --compiled_pipeline_path=$(compiled_pipeline_path) \ + --test_spec_uri="tests/e2e_tests/e2e_tests.yaml" \ + --test_names="hom_cora_sup_gs_ppr_test" + run_het_dblp_sup_gs_e2e_test: compiled_pipeline_path:=${GIGL_E2E_TEST_COMPILED_PIPELINE_PATH} run_het_dblp_sup_gs_e2e_test: compile_gigl_kubeflow_pipeline run_het_dblp_sup_gs_e2e_test: diff --git a/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml b/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml new file mode 100644 index 000000000..46c508819 --- /dev/null +++ b/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml @@ -0,0 +1,75 @@ +# This config runs homogeneous CORA supervised training and inference in Graph Store mode +# with PPR sampling. It intentionally reuses the standard graph-store training/inference +# entrypoints, changing only the sampler args and keeping the loop short for E2E coverage. +graphMetadata: + edgeTypes: + - dstNodeType: paper + relation: cites + srcNodeType: paper + nodeTypes: + - paper +datasetConfig: + dataPreprocessorConfig: + dataPreprocessorConfigClsPath: gigl.src.mocking.mocking_assets.passthrough_preprocessor_config_for_mocked_assets.PassthroughPreprocessorConfigForMockedAssets + dataPreprocessorArgs: + mocked_dataset_name: 'cora_homogeneous_node_anchor_edge_features_user_defined_labels' +trainerConfig: + trainerArgs: + log_every_n_batch: "1" + num_neighbors: "[10, 10]" + sampler_type: "ppr" + ppr_alpha: "0.5" + ppr_eps: "0.0001" + ppr_max_nodes: "20" + ppr_neighbors_per_hop: "100" + ppr_max_fetch_iterations: "2" + sampling_workers_per_process: "2" + main_batch_size: "8" + random_batch_size: "8" + num_max_train_batches: "4" + num_val_batches: "4" + val_every_n_batch: "1" + command: python -m examples.link_prediction.graph_store.homogeneous_training + graphStoreStorageConfig: + command: python -m examples.link_prediction.graph_store.storage_main + storageArgs: + sample_edge_direction: "in" + splitter_cls_path: "gigl.utils.data_splitters.DistNodeAnchorLinkSplitter" + splitter_kwargs: >- + { + "sampling_direction": "in", + "should_convert_labels_to_edges": True, + "num_val": 0.25, + "num_test": 0.25 + } + num_server_sessions: "1" +inferencerConfig: + inferencerArgs: + log_every_n_batch: "1" + num_neighbors: "[10, 10]" + sampler_type: "ppr" + ppr_alpha: "0.5" + ppr_eps: "0.0001" + ppr_max_nodes: "20" + ppr_neighbors_per_hop: "100" + ppr_max_fetch_iterations: "2" + sampling_workers_per_inference_process: "2" + inferenceBatchSize: 256 + command: python -m examples.link_prediction.graph_store.homogeneous_inference + graphStoreStorageConfig: + command: python -m examples.link_prediction.graph_store.storage_main + storageArgs: + sample_edge_direction: "in" + num_server_sessions: "1" +sharedConfig: + shouldSkipInference: false + shouldSkipModelEvaluation: true +taskMetadata: + nodeAnchorBasedLinkPredictionTaskMetadata: + supervisionEdgeTypes: + - dstNodeType: paper + relation: cites + srcNodeType: paper +featureFlags: + should_run_glt_backend: 'True' + data_preprocessor_num_shards: '2' diff --git a/examples/link_prediction/graph_store/homogeneous_inference.py b/examples/link_prediction/graph_store/homogeneous_inference.py index 34bc2672e..5faa84b72 100644 --- a/examples/link_prediction/graph_store/homogeneous_inference.py +++ b/examples/link_prediction/graph_store/homogeneous_inference.py @@ -87,7 +87,7 @@ import sys import time from dataclasses import dataclass -from typing import Union +from typing import Optional, Union import torch import torch.multiprocessing as mp @@ -101,6 +101,7 @@ from gigl.common.utils.gcs import GcsUtils from gigl.distributed.graph_store.compute import init_compute_process from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset +from gigl.distributed.sampler_options import SamplerOptions from gigl.distributed.utils import get_graph_store_info from gigl.env.distributed import GraphStoreInfo from gigl.nn import LinkPredictionGNN @@ -110,16 +111,10 @@ from gigl.src.common.utils.bq import BqUtils from gigl.src.common.utils.model import load_state_dict_from_uri from gigl.src.inference.lib.assets import InferenceAssets -from gigl.utils.sampling import parse_fanout +from gigl.utils.sampling import parse_fanout, parse_sampler_options logger = Logger() -# Default number of inference processes per machine incase one isnt provided in inference args -# i.e. `local_world_size` is not provided, and we can't infer automatically. -# If there are GPUs attached to the machine, we automatically infer to setting -# LOCAL_WORLD_SIZE == # of gpus on the machine. -DEFAULT_CPU_BASED_LOCAL_WORLD_SIZE = 4 - @dataclass(frozen=True) class InferenceProcessArgs: @@ -143,6 +138,7 @@ class InferenceProcessArgs: inference_batch_size (int): Batch size to use for inference. num_neighbors (Union[list[int], dict[EdgeType, list[int]]]): Fanout for subgraph sampling, where the ith item corresponds to the number of items to sample for the ith hop. + sampler_options (Optional[SamplerOptions]): Sampler variant. None uses k-hop sampling. sampling_workers_per_inference_process (int): Number of sampling workers per inference process. sampling_worker_shared_channel_size (str): Shared-memory buffer size (bytes) allocated for @@ -169,6 +165,7 @@ class InferenceProcessArgs: # Inference configuration inference_batch_size: int num_neighbors: Union[list[int], dict[EdgeType, list[int]]] + sampler_options: Optional[SamplerOptions] sampling_workers_per_inference_process: int sampling_worker_shared_channel_size: str log_every_n_batch: int @@ -242,6 +239,7 @@ def _inference_process( # For large-scale settings, consider setting this field to 30-60 seconds to ensure dataloaders # don't compete for memory during initialization, causing OOM process_start_gap_seconds=0, + sampler_options=args.sampler_options, ) # Initialize a LinkPredictionGNN model and load parameters from # the saved model. @@ -455,25 +453,23 @@ def _run_example_inference( if arg_local_world_size is not None: local_world_size = int(arg_local_world_size) logger.info(f"Using local_world_size from inferencer_args: {local_world_size}") - if torch.cuda.is_available() and local_world_size != torch.cuda.device_count(): - logger.warning( - f"local_world_size {local_world_size} does not match the number of GPUs {torch.cuda.device_count()}. " - "This may lead to unexpected failures with NCCL communication incase GPUs are being used for " - + "training/inference. Consider setting local_world_size to the number of GPUs." - ) else: - if torch.cuda.is_available() and torch.cuda.device_count() > 0: - # If GPUs are available, we set the local_world_size to the number of GPUs - local_world_size = torch.cuda.device_count() - logger.info( - f"Detected {local_world_size} GPUs. Thus, setting local_world_size to {local_world_size}" - ) - else: - # If no GPUs are available, we set the local_world_size to the number of inference processes per machine - logger.info( - f"No GPUs detected. Thus, setting local_world_size to `{DEFAULT_CPU_BASED_LOCAL_WORLD_SIZE}`" - ) - local_world_size = DEFAULT_CPU_BASED_LOCAL_WORLD_SIZE + local_world_size = cluster_info.num_processes_per_compute + logger.info( + f"Using local_world_size from cluster_info.num_processes_per_compute: {local_world_size}" + ) + if local_world_size != cluster_info.num_processes_per_compute: + raise ValueError( + f"Graph Store local_world_size={local_world_size} must match " + f"cluster_info.num_processes_per_compute=" + f"{cluster_info.num_processes_per_compute}" + ) + if torch.cuda.is_available() and local_world_size != torch.cuda.device_count(): + logger.warning( + f"local_world_size {local_world_size} does not match the number of GPUs {torch.cuda.device_count()}. " + "This may lead to unexpected failures with NCCL communication incase GPUs are being used for " + + "training/inference. Consider setting local_world_size to the number of GPUs." + ) if cluster_info.compute_node_rank == 0: gcs_utils = GcsUtils() @@ -494,6 +490,7 @@ def _run_example_inference( # Parses the fanout as a string. For the homogeneous case, the fanouts should be specified # as a string of a list of integers, such as "[10, 10]". num_neighbors = parse_fanout(inferencer_args.get("num_neighbors", "[10, 10]")) + sampler_options = parse_sampler_options(inferencer_args) # While the ideal value for `sampling_workers_per_inference_process` has been identified to be # between `2` and `4`, this may need some tuning depending on the pipeline. We default this @@ -516,6 +513,14 @@ def _run_example_inference( log_every_n_batch = int(inferencer_args.get("log_every_n_batch", "50")) + logger.info( + f"Got inference args local_world_size={local_world_size}, " + f"num_neighbors={num_neighbors}, sampler_options={sampler_options}, " + f"sampling_workers_per_inference_process={sampling_workers_per_inference_process}, " + f"sampling_worker_shared_channel_size={sampling_worker_shared_channel_size}, " + f"log_every_n_batch={log_every_n_batch}" + ) + # When using mp.spawn with `nprocs`, the first argument is implicitly set to be the process number on the current machine. inference_args = InferenceProcessArgs( local_world_size=local_world_size, @@ -528,6 +533,7 @@ def _run_example_inference( edge_feature_dim=edge_feature_dim, inference_batch_size=inference_batch_size, num_neighbors=num_neighbors, + sampler_options=sampler_options, sampling_workers_per_inference_process=sampling_workers_per_inference_process, sampling_worker_shared_channel_size=sampling_worker_shared_channel_size, log_every_n_batch=log_every_n_batch, diff --git a/examples/link_prediction/graph_store/homogeneous_training.py b/examples/link_prediction/graph_store/homogeneous_training.py index 04340f99a..c7ae356cc 100644 --- a/examples/link_prediction/graph_store/homogeneous_training.py +++ b/examples/link_prediction/graph_store/homogeneous_training.py @@ -143,6 +143,7 @@ shutdown_compute_process, ) from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset +from gigl.distributed.sampler_options import SamplerOptions from gigl.distributed.utils import get_available_device, get_graph_store_info from gigl.env.distributed import GraphStoreInfo from gigl.nn import LinkPredictionGNN, RetrievalLoss @@ -158,7 +159,7 @@ from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper from gigl.src.common.utils.model import load_state_dict_from_uri, save_state_dict from gigl.utils.iterator import InfiniteIterator -from gigl.utils.sampling import parse_fanout +from gigl.utils.sampling import parse_fanout, parse_sampler_options logger = Logger() @@ -191,6 +192,7 @@ def _setup_dataloaders( split: Literal["train", "val", "test"], cluster_info: GraphStoreInfo, num_neighbors: list[int] | dict[EdgeType, list[int]], + sampler_options: Optional[SamplerOptions], sampling_workers_per_process: int, main_batch_size: int, random_batch_size: int, @@ -205,6 +207,7 @@ def _setup_dataloaders( split (Literal["train", "val", "test"]): The current split which we are loading data for. cluster_info (GraphStoreInfo): Cluster topology info for graph store mode. num_neighbors: Fanout for subgraph sampling. + sampler_options (Optional[SamplerOptions]): Sampler variant. None uses k-hop sampling. sampling_workers_per_process (int): Number of sampling workers per training/testing process. main_batch_size (int): Batch size for main dataloader with query and labeled nodes. random_batch_size (int): Batch size for random negative dataloader. @@ -240,6 +243,7 @@ def _setup_dataloaders( channel_size=sampling_worker_shared_channel_size, process_start_gap_seconds=process_start_gap_seconds, shuffle=shuffle, + sampler_options=sampler_options, ) logger.info(f"---Rank {rank} finished setting up main loader for split={split}") @@ -266,6 +270,7 @@ def _setup_dataloaders( channel_size=sampling_worker_shared_channel_size, process_start_gap_seconds=process_start_gap_seconds, shuffle=shuffle, + sampler_options=sampler_options, ) logger.info( @@ -375,6 +380,7 @@ class TrainingProcessArgs: sampling_workers_per_process (int): Number of sampling workers per training/testing process. sampling_worker_shared_channel_size (str): Shared-memory buffer size for the channel during sampling. process_start_gap_seconds (int): Time to sleep between dataloader initializations. + sampler_options (Optional[SamplerOptions]): Sampler variant. None uses k-hop sampling. main_batch_size (int): Batch size for main dataloader. random_batch_size (int): Batch size for random negative dataloader. learning_rate (float): Learning rate for the optimizer. @@ -400,6 +406,7 @@ class TrainingProcessArgs: # Sampling config num_neighbors: list[int] | dict[EdgeType, list[int]] + sampler_options: Optional[SamplerOptions] sampling_workers_per_process: int sampling_worker_shared_channel_size: str process_start_gap_seconds: int @@ -463,6 +470,7 @@ def _training_process( split="train", cluster_info=args.cluster_info, num_neighbors=args.num_neighbors, + sampler_options=args.sampler_options, sampling_workers_per_process=args.sampling_workers_per_process, main_batch_size=args.main_batch_size, random_batch_size=args.random_batch_size, @@ -481,6 +489,7 @@ def _training_process( split="val", cluster_info=args.cluster_info, num_neighbors=args.num_neighbors, + sampler_options=args.sampler_options, sampling_workers_per_process=args.sampling_workers_per_process, main_batch_size=args.main_batch_size, random_batch_size=args.random_batch_size, @@ -637,6 +646,7 @@ def _training_process( split="test", cluster_info=args.cluster_info, num_neighbors=args.num_neighbors, + sampler_options=args.sampler_options, sampling_workers_per_process=args.sampling_workers_per_process, main_batch_size=args.main_batch_size, random_batch_size=args.random_batch_size, @@ -837,13 +847,17 @@ def _run_example_training( # Training Hyperparameters trainer_args = dict(gbml_config_pb_wrapper.trainer_config.trainer_args) - if torch.cuda.is_available(): - default_local_world_size = torch.cuda.device_count() - else: - default_local_world_size = 2 local_world_size = int( - trainer_args.get("local_world_size", str(default_local_world_size)) + trainer_args.get( + "local_world_size", str(cluster_info.num_processes_per_compute) + ) ) + if local_world_size != cluster_info.num_processes_per_compute: + raise ValueError( + f"Graph Store local_world_size={local_world_size} must match " + f"cluster_info.num_processes_per_compute=" + f"{cluster_info.num_processes_per_compute}" + ) if torch.cuda.is_available(): if local_world_size > torch.cuda.device_count(): @@ -853,6 +867,7 @@ def _run_example_training( fanout = trainer_args.get("num_neighbors", "[10, 10]") num_neighbors = parse_fanout(fanout) + sampler_options = parse_sampler_options(trainer_args) sampling_workers_per_process: int = int( trainer_args.get("sampling_workers_per_process", "4") @@ -880,6 +895,7 @@ def _run_example_training( logger.info( f"Got training args local_world_size={local_world_size}, \ num_neighbors={num_neighbors}, \ + sampler_options={sampler_options}, \ sampling_workers_per_process={sampling_workers_per_process}, \ main_batch_size={main_batch_size}, \ random_batch_size={random_batch_size}, \ @@ -931,6 +947,7 @@ def _run_example_training( node_feature_dim=node_feature_dim, edge_feature_dim=edge_feature_dim, num_neighbors=num_neighbors, + sampler_options=sampler_options, sampling_workers_per_process=sampling_workers_per_process, sampling_worker_shared_channel_size=sampling_worker_shared_channel_size, process_start_gap_seconds=process_start_gap_seconds, diff --git a/gigl/utils/sampling.py b/gigl/utils/sampling.py index 5d0ed6a44..e2c6996e5 100644 --- a/gigl/utils/sampling.py +++ b/gigl/utils/sampling.py @@ -1,10 +1,12 @@ import ast +from collections.abc import Mapping from dataclasses import dataclass from typing import Any, Optional, Union import torch from gigl.common.logger import Logger +from gigl.distributed.sampler_options import PPRSamplerOptions, SamplerOptions from gigl.src.common.types.graph_data import EdgeType, NodeType logger = Logger() @@ -88,6 +90,45 @@ def parse_fanout(fanout_str: str) -> Union[list[int], dict[EdgeType, list[int]]] ) +def _parse_optional_int(value: Optional[str]) -> Optional[int]: + if value is None: + return None + normalized = value.strip().lower() + if normalized in {"", "none", "null"}: + return None + return int(value) + + +def parse_sampler_options(args: Mapping[str, str]) -> Optional[SamplerOptions]: + sampler_type = args.get("sampler_type", "khop").strip().lower().replace("-", "_") + if sampler_type == "": + sampler_type = "khop" + + if sampler_type in {"khop", "k_hop", "neighbor", "neighbor_sampler"}: + return None + + if sampler_type != "ppr": + raise ValueError( + f"Unsupported sampler_type={sampler_type}. Expected one of: khop, ppr." + ) + + max_ppr_nodes = args.get("ppr_max_nodes") + if max_ppr_nodes is None: + max_ppr_nodes = args.get("ppr_max_ppr_nodes", "50") + + num_neighbors_per_hop = args.get("ppr_neighbors_per_hop") + if num_neighbors_per_hop is None: + num_neighbors_per_hop = args.get("ppr_num_neighbors_per_hop", "1000") + + return PPRSamplerOptions( + alpha=float(args.get("ppr_alpha", "0.5")), + eps=float(args.get("ppr_eps", "0.0001")), + max_ppr_nodes=int(max_ppr_nodes), + num_neighbors_per_hop=int(num_neighbors_per_hop), + max_fetch_iterations=_parse_optional_int(args.get("ppr_max_fetch_iterations")), + ) + + @dataclass(frozen=True) class ABLPInputNodes: """Represents ABLP (Anchor Based Link Prediction) input for a single storage server. diff --git a/tests/e2e_tests/e2e_tests.yaml b/tests/e2e_tests/e2e_tests.yaml index 61fc4f311..6d09d8213 100644 --- a/tests/e2e_tests/e2e_tests.yaml +++ b/tests/e2e_tests/e2e_tests.yaml @@ -22,6 +22,9 @@ tests: hom_cora_sup_gs_test: task_config_uri: "examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_task_config.yaml" resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_GRAPH_STORE_RESOURCE_CONFIG,deployment/configs/e2e_glt_gs_resource_config.yaml}" + hom_cora_sup_gs_ppr_test: + task_config_uri: "examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml" + resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_GRAPH_STORE_RESOURCE_CONFIG,deployment/configs/e2e_glt_gs_resource_config.yaml}" het_dblp_sup_gs_test: task_config_uri: "examples/link_prediction/graph_store/configs/e2e_het_dblp_sup_gs_task_config.yaml" resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_GRAPH_STORE_RESOURCE_CONFIG,deployment/configs/e2e_glt_gs_resource_config.yaml}" From 2ef9548f39bbc13979578bbf289cb3374ffc42f9 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Fri, 29 May 2026 17:09:51 +0000 Subject: [PATCH 17/32] Keep PPR test ty ignores --- .../unit/distributed/dist_ppr_sampler_test.py | 21 +++++-------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/tests/unit/distributed/dist_ppr_sampler_test.py b/tests/unit/distributed/dist_ppr_sampler_test.py index 4837f5ef4..b8f324fad 100644 --- a/tests/unit/distributed/dist_ppr_sampler_test.py +++ b/tests/unit/distributed/dist_ppr_sampler_test.py @@ -28,7 +28,7 @@ import heapq from collections import defaultdict -from typing import Literal, TypeGuard +from typing import Literal import networkx as nx import torch @@ -95,14 +95,6 @@ _TEST_MAX_PPR_NODES = 5 -def _is_node_type_to_tensor_map( - value: object, -) -> TypeGuard[dict[str, torch.Tensor]]: - return isinstance(value, dict) and all( - isinstance(node_ids, torch.Tensor) for node_ids in value.values() - ) - - # --------------------------------------------------------------------------- # Reference PPR implementations (NetworkX-based) # --------------------------------------------------------------------------- @@ -516,7 +508,7 @@ def _run_ppr_ablp_loader_correctness_check( ) train_node_ids = dataset.train_node_ids - if not _is_node_type_to_tensor_map(train_node_ids): + if train_node_ids is None or isinstance(train_node_ids, torch.Tensor): raise TypeError( f"Expected train_node_ids to be a dictionary, got {type(train_node_ids)}" ) @@ -524,7 +516,7 @@ def _run_ppr_ablp_loader_correctness_check( loader = DistABLPLoader( dataset=dataset, num_neighbors=[], # Unused by PPR sampler; required by interface - input_nodes=(USER, train_node_ids[USER]), + input_nodes=(USER, train_node_ids[USER]), # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. supervision_edge_type=USER_TO_STORY, sampler_options=PPRSamplerOptions( alpha=alpha, @@ -622,15 +614,12 @@ def _run_ppr_labeled_homogeneous_ablp_loader_check(_: int) -> None: ) train_node_ids = dataset.train_node_ids - if not _is_node_type_to_tensor_map(train_node_ids): - raise TypeError( - f"Expected train_node_ids to be a dictionary, got {type(train_node_ids)}" - ) + assert isinstance(train_node_ids, dict) loader = DistABLPLoader( dataset=dataset, num_neighbors=[], - input_nodes=train_node_ids[DEFAULT_HOMOGENEOUS_NODE_TYPE], + input_nodes=train_node_ids[DEFAULT_HOMOGENEOUS_NODE_TYPE], # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. sampler_options=PPRSamplerOptions( alpha=_TEST_ALPHA, eps=_TEST_EPS, From b08f0e51cbd6338e09f75d534567c337bc606dfd Mon Sep 17 00:00:00 2001 From: mkolodner Date: Fri, 29 May 2026 17:18:00 +0000 Subject: [PATCH 18/32] Remove stale PPR test ty ignore --- tests/unit/distributed/dist_ppr_sampler_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/distributed/dist_ppr_sampler_test.py b/tests/unit/distributed/dist_ppr_sampler_test.py index b8f324fad..c01a6da10 100644 --- a/tests/unit/distributed/dist_ppr_sampler_test.py +++ b/tests/unit/distributed/dist_ppr_sampler_test.py @@ -516,7 +516,7 @@ def _run_ppr_ablp_loader_correctness_check( loader = DistABLPLoader( dataset=dataset, num_neighbors=[], # Unused by PPR sampler; required by interface - input_nodes=(USER, train_node_ids[USER]), # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. + input_nodes=(USER, train_node_ids[USER]), supervision_edge_type=USER_TO_STORY, sampler_options=PPRSamplerOptions( alpha=alpha, From a6eedd1a727839273f08e97f0802f76085f93e66 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Fri, 29 May 2026 18:06:06 +0000 Subject: [PATCH 19/32] Use union shape for PPR degree tensors --- gigl/distributed/base_dist_loader.py | 10 ++ gigl/distributed/dist_dataset.py | 34 +++-- gigl/distributed/dist_ppr_sampler.py | 40 +++-- gigl/distributed/dist_sampling_producer.py | 6 +- .../shared_dist_sampling_producer.py | 19 ++- gigl/distributed/utils/degree.py | 143 ++++++++++-------- gigl/distributed/utils/dist_sampler.py | 2 +- tests/unit/distributed/utils/degree_test.py | 26 ++-- 8 files changed, 175 insertions(+), 105 deletions(-) diff --git a/gigl/distributed/base_dist_loader.py b/gigl/distributed/base_dist_loader.py index b8f6f2c87..c298aa0d2 100644 --- a/gigl/distributed/base_dist_loader.py +++ b/gigl/distributed/base_dist_loader.py @@ -466,6 +466,16 @@ def create_mp_producer( channel = BaseDistLoader.create_colocated_channel(worker_options) if isinstance(sampler_options, PPRSamplerOptions): degree_tensors = dataset.degree_tensor + if isinstance(degree_tensors, dict): + logger.info( + f"Pre-computed degree tensors for PPR sampling across " + f"{len(degree_tensors)} node types." + ) + else: + logger.info( + f"Pre-computed degree tensor for PPR sampling with " + f"{degree_tensors.size(0)} nodes." + ) else: degree_tensors = None return DistSamplingProducer( diff --git a/gigl/distributed/dist_dataset.py b/gigl/distributed/dist_dataset.py index 31c97410a..181c2c7d9 100644 --- a/gigl/distributed/dist_dataset.py +++ b/gigl/distributed/dist_dataset.py @@ -80,7 +80,9 @@ def __init__( edge_feature_info: Optional[ Union[FeatureInfo, dict[EdgeType, FeatureInfo]] ] = None, - degree_tensor: Optional[dict[NodeType, torch.Tensor]] = None, + degree_tensor: Optional[ + Union[torch.Tensor, dict[NodeType, torch.Tensor]] + ] = None, max_labels_per_anchor_node: Optional[int] = None, edge_weights: Optional[ Union[torch.Tensor, dict[EdgeType, torch.Tensor]] @@ -109,7 +111,7 @@ def __init__( Note this will be None in the homogeneous case if the data has no node features, or will only contain node types with node features in the heterogeneous case. edge_feature_info: Optional[Union[FeatureInfo, dict[EdgeType, FeatureInfo]]]: Dimension of edge features and its data type, will be a dict if heterogeneous. Note this will be None in the homogeneous case if the data has no edge features, or will only contain edge types with edge features in the heterogeneous case. - degree_tensor: Optional[dict[NodeType, torch.Tensor]]: Pre-computed degree tensor keyed by node type. Lazily computed on first access via the degree_tensor property. + degree_tensor: Optional[Union[torch.Tensor, dict[NodeType, torch.Tensor]]]: Pre-computed degree tensor. Lazily computed on first access via the degree_tensor property. max_labels_per_anchor_node (Optional[int]): Optional cap for how many labels to materialize per anchor node for ABLP label fetching. edge_weights: Per-edge sampling weights for this rank's partition. @@ -149,7 +151,9 @@ def __init__( self._node_feature_info = node_feature_info self._edge_feature_info = edge_feature_info - self._degree_tensor: Optional[dict[NodeType, torch.Tensor]] = degree_tensor + self._degree_tensor: Optional[ + Union[torch.Tensor, dict[NodeType, torch.Tensor]] + ] = degree_tensor self._max_labels_per_anchor_node = max_labels_per_anchor_node self._edge_weights: Optional[ Union[torch.Tensor, dict[EdgeType, torch.Tensor]] @@ -311,15 +315,15 @@ def edge_feature_info( @property def degree_tensor( self, - ) -> dict[NodeType, torch.Tensor]: + ) -> Union[torch.Tensor, dict[NodeType, torch.Tensor]]: """ - Lazily compute and return the total degree tensor per node type. + Lazily compute and return degree tensors for the graph. On first access, computes node degrees from the graph partition and uses - all-reduce to aggregate across all machines. Degrees are summed across - all incident edge types per anchor node type before the all-reduce, so - the per-edge-type tensor is never stored. Requires torch.distributed to - be initialized. + all-reduce to aggregate across all machines. For heterogeneous graphs, + degrees are summed across all incident edge types per anchor node type + before the all-reduce, so the per-edge-type tensor is never stored. + Requires torch.distributed to be initialized. Over-counting correction (for processes sharing the same data on the same machine) is handled automatically by detecting the distributed topology. @@ -327,9 +331,9 @@ def degree_tensor( The result is cached for subsequent accesses. Returns: - dict[NodeType, torch.Tensor]: Total degree tensors keyed by node type. - For homogeneous graphs the single entry uses - ``DEFAULT_HOMOGENEOUS_NODE_TYPE`` as its key. + Union[torch.Tensor, dict[NodeType, torch.Tensor]]: Degree tensor for + homogeneous graphs, or total degree tensors keyed by node type + for heterogeneous graphs. Raises: RuntimeError: If torch.distributed is not initialized. @@ -943,7 +947,7 @@ def share_ipc( Optional[Union[int, dict[NodeType, int]]], Optional[Union[FeatureInfo, dict[NodeType, FeatureInfo]]], Optional[Union[FeatureInfo, dict[EdgeType, FeatureInfo]]], - Optional[dict[NodeType, torch.Tensor]], + Optional[Union[torch.Tensor, dict[NodeType, torch.Tensor]]], Optional[int], Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]], ]: @@ -967,7 +971,7 @@ def share_ipc( Optional[Union[int, dict[NodeType, int]]]: Number of test nodes on the current machine. Will be a dict if heterogeneous. Optional[Union[FeatureInfo, dict[NodeType, FeatureInfo]]]: Node feature dim and its data type, will be a dict if heterogeneous Optional[Union[FeatureInfo, dict[EdgeType, FeatureInfo]]]: Edge feature dim and its data type, will be a dict if heterogeneous - Optional[dict[NodeType, torch.Tensor]]: Degree tensors keyed by node type + Optional[Union[torch.Tensor, dict[NodeType, torch.Tensor]]]: Degree tensors Optional[int]: Optional per-anchor label cap for ABLP label fetching Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]]: Per-edge sampling weights for this rank's partition """ @@ -1256,7 +1260,7 @@ def _rebuild_distributed_dataset( Optional[ Union[FeatureInfo, dict[EdgeType, FeatureInfo]] ], # Edge feature dim and its data type - Optional[dict[NodeType, torch.Tensor]], # Degree tensors + Optional[Union[torch.Tensor, dict[NodeType, torch.Tensor]]], # Degree tensors Optional[int], # Optional per-anchor label cap for ABLP label fetching Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]], # edge_weights ], diff --git a/gigl/distributed/dist_ppr_sampler.py b/gigl/distributed/dist_ppr_sampler.py index 80049f305..644a1ee76 100644 --- a/gigl/distributed/dist_ppr_sampler.py +++ b/gigl/distributed/dist_ppr_sampler.py @@ -74,10 +74,11 @@ class DistPPRNeighborSampler(BaseDistNeighborSampler): but require more computation. Typical values: 1e-4 to 1e-6. max_ppr_nodes: Maximum number of nodes to return per seed based on PPR scores. num_neighbors_per_hop: Maximum number of neighbors to fetch per hop. - degree_tensors: Pre-computed total-degree tensors (int32), keyed by NodeType. - Must be pre-computed by the caller through - ``DistDataset.degree_tensor`` so that workers share a single - allocation rather than recomputing per-worker. + degree_tensors: Pre-computed total-degree tensors (int32). Homogeneous + graphs use a single tensor; heterogeneous graphs use tensors keyed + by NodeType. Must be pre-computed by the caller through + ``DistDataset.degree_tensor`` so workers share a single allocation + rather than recomputing per-worker. """ def __init__( @@ -87,7 +88,7 @@ def __init__( eps: float = 1e-4, max_ppr_nodes: int = 50, num_neighbors_per_hop: int = 100_000, - degree_tensors: dict[NodeType, torch.Tensor], + degree_tensors: Union[torch.Tensor, dict[NodeType, torch.Tensor]], max_fetch_iterations: Optional[int] = None, **kwargs, ): @@ -129,11 +130,9 @@ def __init__( ] self._is_homogeneous = True - # Total-degree tensors keyed by NodeType, pre-computed by the caller. - # Callers compute DistDataset.degree_tensor once in the parent process - # and place the result in shared memory so all worker processes map the - # same allocation. - self._node_type_to_total_degree: dict[NodeType, torch.Tensor] = degree_tensors + # Normalize the public homogeneous/heterogeneous degree-tensor shape to + # the node-type keyed form used internally by PPR. + self._node_type_to_total_degree = self._normalize_degree_tensors(degree_tensors) # Build integer ID mappings for the C++ forward-push kernel. String # NodeType / EdgeType keys are only used at the Python boundary @@ -183,6 +182,27 @@ def __init__( for nt in all_node_types ] + def _normalize_degree_tensors( + self, + degree_tensors: Union[torch.Tensor, dict[NodeType, torch.Tensor]], + ) -> dict[NodeType, torch.Tensor]: + """Normalize degree tensors to the node-type keyed shape PPR uses.""" + if isinstance(degree_tensors, torch.Tensor): + if not self._is_homogeneous: + raise ValueError( + "Expected degree tensors keyed by node type for heterogeneous PPR sampling." + ) + return {DEFAULT_HOMOGENEOUS_NODE_TYPE: degree_tensors} + + missing_anchor_types = set(self._node_type_to_edge_types.keys()) - set( + degree_tensors.keys() + ) + if missing_anchor_types: + raise ValueError( + f"Missing PPR degree tensors for node types: {missing_anchor_types}" + ) + return degree_tensors + def _get_destination_type(self, edge_type: EdgeType) -> NodeType: """Get the node type at the destination end of an edge type.""" return edge_type[0] if self.edge_dir == "in" else edge_type[-1] diff --git a/gigl/distributed/dist_sampling_producer.py b/gigl/distributed/dist_sampling_producer.py index 15d29a48c..7a11ea703 100644 --- a/gigl/distributed/dist_sampling_producer.py +++ b/gigl/distributed/dist_sampling_producer.py @@ -55,7 +55,7 @@ def _sampling_worker_loop( sampling_completed_worker_count, # mp.Value mp_barrier: Barrier, sampler_options: SamplerOptions, - degree_tensors: Optional[dict[NodeType, torch.Tensor]], + degree_tensors: Optional[Union[torch.Tensor, dict[NodeType, torch.Tensor]]], ): dist_sampler = None try: @@ -180,7 +180,9 @@ def __init__( worker_options: MpDistSamplingWorkerOptions, channel: ChannelBase, sampler_options: SamplerOptions, - degree_tensors: Optional[dict[NodeType, torch.Tensor]] = None, + degree_tensors: Optional[ + Union[torch.Tensor, dict[NodeType, torch.Tensor]] + ] = None, ): super().__init__(data, sampler_input, sampling_config, worker_options, channel) self._sampler_options = sampler_options diff --git a/gigl/distributed/graph_store/shared_dist_sampling_producer.py b/gigl/distributed/graph_store/shared_dist_sampling_producer.py index c6564a39d..e8d679c7b 100644 --- a/gigl/distributed/graph_store/shared_dist_sampling_producer.py +++ b/gigl/distributed/graph_store/shared_dist_sampling_producer.py @@ -339,7 +339,7 @@ def _shared_sampling_worker_loop( event_queue: mp.Queue, mp_barrier: Barrier, sampler_options: SamplerOptions, - degree_tensors: Optional[dict[NodeType, torch.Tensor]], + degree_tensors: Optional[Union[torch.Tensor, dict[NodeType, torch.Tensor]]], ) -> None: """Run one shared graph-store worker that schedules many input channels. @@ -836,7 +836,7 @@ def __init__( worker_options: RemoteDistSamplingWorkerOptions, sampling_config: SamplingConfig, sampler_options: SamplerOptions, - degree_tensors: Optional[dict[NodeType, torch.Tensor]], + degree_tensors: Optional[Union[torch.Tensor, dict[NodeType, torch.Tensor]]], ) -> None: """Initialize the shared sampling backend. @@ -874,7 +874,20 @@ def __init__( ) # Move degree tensors to shared memory so all spawned workers map the # same allocation instead of each pickling a private copy. - self._degree_tensors: Optional[dict[NodeType, torch.Tensor]] = degree_tensors + self._degree_tensors: Optional[ + Union[torch.Tensor, dict[NodeType, torch.Tensor]] + ] = degree_tensors + if degree_tensors is not None: + if isinstance(degree_tensors, dict): + logger.info( + f"Pre-computed degree tensors for PPR sampling across " + f"{len(degree_tensors)} node types." + ) + else: + logger.info( + f"Pre-computed degree tensor for PPR sampling with " + f"{degree_tensors.size(0)} nodes." + ) share_memory(self._degree_tensors) def init_backend(self) -> None: diff --git a/gigl/distributed/utils/degree.py b/gigl/distributed/utils/degree.py index d33ec74f0..01a58fcaa 100644 --- a/gigl/distributed/utils/degree.py +++ b/gigl/distributed/utils/degree.py @@ -5,9 +5,10 @@ and aggregate them across distributed machines. Degrees are computed from the CSR (Compressed Sparse Row) topology stored in GraphLearn-Torch Graph objects. -Degrees are accumulated per anchor node type (summing across all edge types -incident to that node type) before the distributed all-reduce, so callers -receive ``dict[NodeType, torch.Tensor]`` directly with no further conversion. +For homogeneous graphs, callers receive a single ``torch.Tensor``. For +heterogeneous graphs, degrees are accumulated per anchor node type (summing +across all edge types incident to that node type) before the distributed +all-reduce, so callers receive ``dict[NodeType, torch.Tensor]``. Requirements ============ @@ -34,7 +35,7 @@ from gigl.common.logger import Logger from gigl.distributed.utils.device import get_device_from_process_group from gigl.distributed.utils.networking import get_internal_ip_from_all_ranks -from gigl.types.graph import DEFAULT_HOMOGENEOUS_NODE_TYPE, is_label_edge_type +from gigl.types.graph import is_label_edge_type logger = Logger() @@ -42,7 +43,7 @@ def compute_and_broadcast_degree_tensor( graph: Union[Graph, dict[EdgeType, Graph]], edge_dir: str, -) -> dict[NodeType, torch.Tensor]: +) -> Union[torch.Tensor, dict[NodeType, torch.Tensor]]: """Compute node degrees from a graph and aggregate across all machines. For each non-label edge type, degrees are derived from the CSR row pointers @@ -63,10 +64,10 @@ def compute_and_broadcast_degree_tensor( end of each edge is the anchor node type for degree accumulation. Returns: - dict[NodeType, torch.Tensor]: Aggregated degree tensors keyed by node - type. For homogeneous graphs the single entry uses - ``DEFAULT_HOMOGENEOUS_NODE_TYPE`` as its key. Values are int32 - tensors of shape ``[num_nodes_of_that_type]``. + Union[torch.Tensor, dict[NodeType, torch.Tensor]]: Aggregated degree + tensors. For homogeneous graphs, returns an int32 tensor of shape + ``[num_nodes]``. For heterogeneous graphs, returns int32 tensors + keyed by node type with shape ``[num_nodes_of_that_type]``. Raises: RuntimeError: If torch.distributed is not initialized. @@ -77,40 +78,43 @@ def compute_and_broadcast_degree_tensor( "compute_and_broadcast_degree_tensor requires torch.distributed to be initialized." ) - local_dict: dict[NodeType, torch.Tensor] = {} - if isinstance(graph, Graph): topo = graph.topo if topo is None or topo.indptr is None: raise ValueError("Topology/indptr not available for graph.") - local_dict[DEFAULT_HOMOGENEOUS_NODE_TYPE] = _compute_degrees_from_indptr( - topo.indptr - ) - else: - for edge_type, edge_graph in graph.items(): - if is_label_edge_type(edge_type): - continue - anchor_type: NodeType = edge_type[-1] if edge_dir == "in" else edge_type[0] - topo = edge_graph.topo - if topo is None or topo.indptr is None: - logger.warning( - f"Topology/indptr not available for edge type {edge_type}, using empty tensor." - ) - degrees = torch.empty(0, dtype=torch.int32) - else: - degrees = _compute_degrees_from_indptr(topo.indptr) - - if anchor_type in local_dict: - existing = local_dict[anchor_type] - max_len = max(len(existing), len(degrees)) - summed = _pad_to_size(existing, max_len).to(torch.int64) - summed[: len(degrees)] += degrees.to(torch.int64) - local_dict[anchor_type] = summed.to(torch.int32) - else: - local_dict[anchor_type] = degrees + result = _all_reduce_degree_tensor(_compute_degrees_from_indptr(topo.indptr)) + if result.numel() > 0: + logger.info( + f"{result.size(0)} nodes, max={result.max().item()}, min={result.min().item()}" + ) + else: + logger.info("Graph contained 0 nodes when computing degrees") + return result - result = _all_reduce_degrees(local_dict) + local_dict: dict[NodeType, torch.Tensor] = {} + for edge_type, edge_graph in graph.items(): + if is_label_edge_type(edge_type): + continue + anchor_type: NodeType = edge_type[-1] if edge_dir == "in" else edge_type[0] + topo = edge_graph.topo + if topo is None or topo.indptr is None: + logger.warning( + f"Topology/indptr not available for edge type {edge_type}, using empty tensor." + ) + degrees = torch.empty(0, dtype=torch.int32) + else: + degrees = _compute_degrees_from_indptr(topo.indptr) + + if anchor_type in local_dict: + existing = local_dict[anchor_type] + max_len = max(len(existing), len(degrees)) + summed = _pad_to_size(existing, max_len).to(torch.int64) + summed[: len(degrees)] += degrees.to(torch.int64) + local_dict[anchor_type] = summed.to(torch.int32) + else: + local_dict[anchor_type] = degrees + result = _all_reduce_degrees(local_dict) for node_type, degrees in result.items(): if degrees.numel() > 0: logger.info( @@ -142,6 +146,43 @@ def _compute_degrees_from_indptr(indptr: torch.Tensor) -> torch.Tensor: return (indptr[1:] - indptr[:-1]).to(torch.int32) +def _get_degree_reduce_context() -> tuple[int, torch.device]: + """Return local-world-size correction factor and all-reduce device.""" + if not torch.distributed.is_initialized(): + raise RuntimeError( + "_all_reduce_degrees requires torch.distributed to be initialized." + ) + + all_ips = get_internal_ip_from_all_ranks() + my_rank = torch.distributed.get_rank() + my_ip = all_ips[my_rank] + local_world_size = Counter(all_ips)[my_ip] + device = get_device_from_process_group() + return local_world_size, device + + +def _all_reduce_single_degree_tensor( + tensor: torch.Tensor, + local_world_size: int, + device: torch.device, +) -> torch.Tensor: + """All-reduce a single tensor with size sync and over-counting correction.""" + local_size = torch.tensor([tensor.size(0)], dtype=torch.long, device=device) + torch.distributed.all_reduce(local_size, op=torch.distributed.ReduceOp.MAX) + max_size = int(local_size.item()) + + padded = _pad_to_size(tensor, max_size).to(torch.int64).to(device) + torch.distributed.all_reduce(padded, op=torch.distributed.ReduceOp.SUM) + + return (padded // local_world_size).to(torch.int32).cpu() + + +def _all_reduce_degree_tensor(tensor: torch.Tensor) -> torch.Tensor: + """All-reduce a homogeneous degree tensor across ranks.""" + local_world_size, device = _get_degree_reduce_context() + return _all_reduce_single_degree_tensor(tensor, local_world_size, device) + + def _all_reduce_degrees( local_degrees: dict[NodeType, torch.Tensor], ) -> dict[NodeType, torch.Tensor]: @@ -172,7 +213,6 @@ def _all_reduce_degrees( Args: local_degrees: Dict mapping NodeType to local degree tensors. - All partitions must have entries for all node types. Returns: Aggregated degree tensors keyed by NodeType. @@ -180,30 +220,11 @@ def _all_reduce_degrees( Raises: RuntimeError: If torch.distributed is not initialized. """ - if not torch.distributed.is_initialized(): - raise RuntimeError( - "_all_reduce_degrees requires torch.distributed to be initialized." - ) - - all_ips = get_internal_ip_from_all_ranks() - my_rank = torch.distributed.get_rank() - my_ip = all_ips[my_rank] - local_world_size = Counter(all_ips)[my_ip] - - device = get_device_from_process_group() - - def reduce_tensor(tensor: torch.Tensor) -> torch.Tensor: - """All-reduce a single tensor with size sync and over-counting correction.""" - local_size = torch.tensor([tensor.size(0)], dtype=torch.long, device=device) - torch.distributed.all_reduce(local_size, op=torch.distributed.ReduceOp.MAX) - max_size = int(local_size.item()) - - padded = _pad_to_size(tensor, max_size).to(torch.int64).to(device) - torch.distributed.all_reduce(padded, op=torch.distributed.ReduceOp.SUM) - - return (padded // local_world_size).to(torch.int32).cpu() + local_world_size, device = _get_degree_reduce_context() result: dict[NodeType, torch.Tensor] = {} for node_type in sorted(local_degrees.keys()): - result[node_type] = reduce_tensor(local_degrees[node_type]) + result[node_type] = _all_reduce_single_degree_tensor( + local_degrees[node_type], local_world_size, device + ) return result diff --git a/gigl/distributed/utils/dist_sampler.py b/gigl/distributed/utils/dist_sampler.py index db5dba1af..55fdb301e 100644 --- a/gigl/distributed/utils/dist_sampler.py +++ b/gigl/distributed/utils/dist_sampler.py @@ -35,7 +35,7 @@ def create_dist_sampler( worker_options: Union[MpDistSamplingWorkerOptions, RemoteDistSamplingWorkerOptions], channel: ChannelBase, sampler_options: SamplerOptions, - degree_tensors: Optional[dict[NodeType, torch.Tensor]], + degree_tensors: Optional[Union[torch.Tensor, dict[NodeType, torch.Tensor]]], current_device: torch.device, ) -> SamplerRuntime: """Create a GiGL sampler runtime for one channel on one worker. diff --git a/tests/unit/distributed/utils/degree_test.py b/tests/unit/distributed/utils/degree_test.py index 5bd84651e..85eeaa0e2 100644 --- a/tests/unit/distributed/utils/degree_test.py +++ b/tests/unit/distributed/utils/degree_test.py @@ -12,7 +12,6 @@ compute_and_broadcast_degree_tensor, ) from gigl.src.common.types.graph_data import EdgeType, NodeType -from gigl.types.graph import DEFAULT_HOMOGENEOUS_NODE_TYPE from tests.test_assets.distributed.test_dataset import ( DEFAULT_HETEROGENEOUS_EDGE_INDICES, DEFAULT_HOMOGENEOUS_EDGE_INDEX, @@ -107,10 +106,10 @@ def test_homogeneous_graph(self): assert dataset.graph is not None result = compute_and_broadcast_degree_tensor(dataset.graph, dataset.edge_dir) - self.assertEqual(set(result.keys()), {DEFAULT_HOMOGENEOUS_NODE_TYPE}) + assert isinstance(result, torch.Tensor) expected = _compute_expected_degrees_from_edge_index(edge_index, num_nodes) - self.assertEqual(result[DEFAULT_HOMOGENEOUS_NODE_TYPE].shape[0], num_nodes) - self.assert_tensor_equality(result[DEFAULT_HOMOGENEOUS_NODE_TYPE], expected) + self.assertEqual(result.shape[0], num_nodes) + self.assert_tensor_equality(result, expected) def test_heterogeneous_graph(self): """Test degree computation for a heterogeneous graph.""" @@ -120,6 +119,7 @@ def test_heterogeneous_graph(self): assert dataset.graph is not None result = compute_and_broadcast_degree_tensor(dataset.graph, dataset.edge_dir) + assert not isinstance(result, torch.Tensor) expected = _compute_expected_total_degrees_by_node_type( edge_indices=edge_indices, edge_dir=dataset.edge_dir, @@ -161,6 +161,7 @@ def test_heterogeneous_graph_with_missing_topology(self): result = compute_and_broadcast_degree_tensor(dataset.graph, dataset.edge_dir) + assert not isinstance(result, torch.Tensor) expected_node_types = { _get_anchor_node_type(edge_type, dataset.edge_dir) for edge_type in edge_types @@ -189,7 +190,7 @@ def _run_local_world_size_correction_homogeneous( world_size: int, init_method: str, edge_index: torch.Tensor, - expected_degrees: dict[NodeType, torch.Tensor], + expected_degrees: torch.Tensor, ) -> None: """Worker function for multi-process local_world_size correction test (homogeneous).""" dist.init_process_group( @@ -203,9 +204,8 @@ def _run_local_world_size_correction_homogeneous( assert dataset.graph is not None result = compute_and_broadcast_degree_tensor(dataset.graph, dataset.edge_dir) - assert set(result.keys()) == set(expected_degrees.keys()) - for node_type, expected in expected_degrees.items(): - assert_tensor_equality(result[node_type], expected) + assert isinstance(result, torch.Tensor) + assert_tensor_equality(result, expected_degrees) finally: dist.destroy_process_group() @@ -229,6 +229,7 @@ def _run_local_world_size_correction_heterogeneous( assert dataset.graph is not None result = compute_and_broadcast_degree_tensor(dataset.graph, dataset.edge_dir) + assert not isinstance(result, torch.Tensor) assert set(result.keys()) == set(expected_degrees.keys()) for node_type, expected in expected_degrees.items(): assert_tensor_equality(result[node_type], expected) @@ -250,9 +251,7 @@ def test_local_world_size_correction_homogeneous(self): num_nodes = int(edge_index.max().item() + 1) raw_degrees = _compute_expected_degrees_from_edge_index(edge_index, num_nodes) - expected_degrees = { - DEFAULT_HOMOGENEOUS_NODE_TYPE: raw_degrees - } # After correction: (2*raw) / 2 = raw + expected_degrees = raw_degrees # After correction: (2*raw) / 2 = raw init_method = get_process_group_init_method() mp.spawn( @@ -300,9 +299,9 @@ def test_degree_tensor_homogeneous(self): dataset = create_homogeneous_dataset(edge_index=edge_index) result = dataset.degree_tensor - self.assertEqual(set(result.keys()), {DEFAULT_HOMOGENEOUS_NODE_TYPE}) + assert isinstance(result, torch.Tensor) expected = _compute_expected_degrees_from_edge_index(edge_index, num_nodes) - self.assert_tensor_equality(result[DEFAULT_HOMOGENEOUS_NODE_TYPE], expected) + self.assert_tensor_equality(result, expected) def test_degree_tensor_caches_result(self): """Test that degree_tensor property caches the result.""" @@ -320,6 +319,7 @@ def test_degree_tensor_heterogeneous(self): result = dataset.degree_tensor + assert not isinstance(result, torch.Tensor) expected = _compute_expected_total_degrees_by_node_type( edge_indices=edge_indices, edge_dir=dataset.edge_dir, From 68ab0f20bc0f99af412c2082395bd365fd408808 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Fri, 29 May 2026 18:10:00 +0000 Subject: [PATCH 20/32] Restore useful degree computation comments --- gigl/distributed/utils/degree.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/gigl/distributed/utils/degree.py b/gigl/distributed/utils/degree.py index 01a58fcaa..5e543874c 100644 --- a/gigl/distributed/utils/degree.py +++ b/gigl/distributed/utils/degree.py @@ -78,10 +78,13 @@ def compute_and_broadcast_degree_tensor( "compute_and_broadcast_degree_tensor requires torch.distributed to be initialized." ) + # Compute local degrees from graph topology. if isinstance(graph, Graph): topo = graph.topo if topo is None or topo.indptr is None: raise ValueError("Topology/indptr not available for graph.") + + # Homogeneous graphs keep the usual GiGL shape: a single tensor. result = _all_reduce_degree_tensor(_compute_degrees_from_indptr(topo.indptr)) if result.numel() > 0: logger.info( @@ -93,6 +96,8 @@ def compute_and_broadcast_degree_tensor( local_dict: dict[NodeType, torch.Tensor] = {} for edge_type, edge_graph in graph.items(): + # Label edge types are supervision edges and should not contribute to + # node degree for traversal algorithms like PPR. if is_label_edge_type(edge_type): continue anchor_type: NodeType = edge_type[-1] if edge_dir == "in" else edge_type[0] @@ -114,6 +119,7 @@ def compute_and_broadcast_degree_tensor( else: local_dict[anchor_type] = degrees + # All-reduce across ranks after local per-node-type aggregation. result = _all_reduce_degrees(local_dict) for node_type, degrees in result.items(): if degrees.numel() > 0: @@ -153,10 +159,13 @@ def _get_degree_reduce_context() -> tuple[int, torch.device]: "_all_reduce_degrees requires torch.distributed to be initialized." ) + # Compute local_world_size: number of processes on the same machine sharing data. all_ips = get_internal_ip_from_all_ranks() my_rank = torch.distributed.get_rank() my_ip = all_ips[my_rank] local_world_size = Counter(all_ips)[my_ip] + + # NCCL backend requires CUDA tensors; Gloo works with CPU. device = get_device_from_process_group() return local_world_size, device @@ -167,13 +176,17 @@ def _all_reduce_single_degree_tensor( device: torch.device, ) -> torch.Tensor: """All-reduce a single tensor with size sync and over-counting correction.""" + # Synchronize max size across all ranks. local_size = torch.tensor([tensor.size(0)], dtype=torch.long, device=device) torch.distributed.all_reduce(local_size, op=torch.distributed.ReduceOp.MAX) max_size = int(local_size.item()) + # Pad, convert to int64 for all_reduce, and move to the process-group device. padded = _pad_to_size(tensor, max_size).to(torch.int64).to(device) torch.distributed.all_reduce(padded, op=torch.distributed.ReduceOp.SUM) + # Correct for over-counting and move back to CPU. We keep int32 so high-degree + # nodes do not saturate at int16. return (padded // local_world_size).to(torch.int32).cpu() @@ -222,6 +235,7 @@ def _all_reduce_degrees( """ local_world_size, device = _get_degree_reduce_context() + # Heterogeneous case: all-reduce each node type in deterministic order. result: dict[NodeType, torch.Tensor] = {} for node_type in sorted(local_degrees.keys()): result[node_type] = _all_reduce_single_degree_tensor( From e71ccdb8e0bd2a589842440a349f69890b5c83db Mon Sep 17 00:00:00 2001 From: mkolodner Date: Fri, 29 May 2026 18:17:42 +0000 Subject: [PATCH 21/32] Remove sampler diagnostic wrapper --- gigl/distributed/base_sampler.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/gigl/distributed/base_sampler.py b/gigl/distributed/base_sampler.py index e8e6f9e77..986ba5d58 100644 --- a/gigl/distributed/base_sampler.py +++ b/gigl/distributed/base_sampler.py @@ -1,4 +1,3 @@ -import logging from collections import defaultdict from dataclasses import dataclass from typing import Optional, Union @@ -214,15 +213,11 @@ async def _send_adapter( Copied from ``graphlearn_torch.distributed.DistNeighborSampler._send_adapter`` (GLT 0.2.4) with the single change of ``_colloate_fn`` → ``_collate_fn``. """ - try: - sampler_output = await async_func(*args, **kwargs) - res = await self._collate_fn(sampler_output) - if self.channel is None: - return res - self.channel.send(res) - except Exception: - logging.exception("sampler task failed") - raise + sampler_output = await async_func(*args, **kwargs) + res = await self._collate_fn(sampler_output) + if self.channel is None: + return res + self.channel.send(res) return None async def _collate_fn( From f76e548b7ed280d3f1eb71bfd450aa9258de7397 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Fri, 29 May 2026 18:24:34 +0000 Subject: [PATCH 22/32] Simplify degree all-reduce helper --- gigl/distributed/utils/degree.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/gigl/distributed/utils/degree.py b/gigl/distributed/utils/degree.py index 5e543874c..54f21df74 100644 --- a/gigl/distributed/utils/degree.py +++ b/gigl/distributed/utils/degree.py @@ -25,7 +25,7 @@ """ from collections import Counter -from typing import Union +from typing import Union, overload import torch from graphlearn_torch.data import Graph @@ -85,7 +85,7 @@ def compute_and_broadcast_degree_tensor( raise ValueError("Topology/indptr not available for graph.") # Homogeneous graphs keep the usual GiGL shape: a single tensor. - result = _all_reduce_degree_tensor(_compute_degrees_from_indptr(topo.indptr)) + result = _all_reduce_degrees(_compute_degrees_from_indptr(topo.indptr)) if result.numel() > 0: logger.info( f"{result.size(0)} nodes, max={result.max().item()}, min={result.min().item()}" @@ -190,15 +190,19 @@ def _all_reduce_single_degree_tensor( return (padded // local_world_size).to(torch.int32).cpu() -def _all_reduce_degree_tensor(tensor: torch.Tensor) -> torch.Tensor: - """All-reduce a homogeneous degree tensor across ranks.""" - local_world_size, device = _get_degree_reduce_context() - return _all_reduce_single_degree_tensor(tensor, local_world_size, device) +@overload +def _all_reduce_degrees( + local_degrees: dict[NodeType, torch.Tensor], +) -> dict[NodeType, torch.Tensor]: ... + + +@overload +def _all_reduce_degrees(local_degrees: torch.Tensor) -> torch.Tensor: ... def _all_reduce_degrees( - local_degrees: dict[NodeType, torch.Tensor], -) -> dict[NodeType, torch.Tensor]: + local_degrees: Union[torch.Tensor, dict[NodeType, torch.Tensor]], +) -> Union[torch.Tensor, dict[NodeType, torch.Tensor]]: """All-reduce degree tensors across ranks. Moves tensors to GPU for the all-reduce if using NCCL backend (which @@ -225,16 +229,20 @@ def _all_reduce_degrees( over-counting. Args: - local_degrees: Dict mapping NodeType to local degree tensors. + local_degrees: Either a homogeneous degree tensor or a dict mapping + NodeType to local degree tensors. Returns: - Aggregated degree tensors keyed by NodeType. + Aggregated degree tensors matching the input shape. Raises: RuntimeError: If torch.distributed is not initialized. """ local_world_size, device = _get_degree_reduce_context() + if isinstance(local_degrees, torch.Tensor): + return _all_reduce_single_degree_tensor(local_degrees, local_world_size, device) + # Heterogeneous case: all-reduce each node type in deterministic order. result: dict[NodeType, torch.Tensor] = {} for node_type in sorted(local_degrees.keys()): From 23ee86f3ff88d4683f847ba08770b61e1368ac63 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Fri, 29 May 2026 18:32:34 +0000 Subject: [PATCH 23/32] Document degree tensor assumptions --- gigl/distributed/utils/degree.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/gigl/distributed/utils/degree.py b/gigl/distributed/utils/degree.py index 54f21df74..6306b09de 100644 --- a/gigl/distributed/utils/degree.py +++ b/gigl/distributed/utils/degree.py @@ -22,6 +22,14 @@ Over-counting correction is handled automatically in _all_reduce_degrees by detecting how many processes share the same machine (and thus the same data). + +Heterogeneous partitioned graphs are expected to materialize all registered +non-label edge types on every rank, even when a rank has no local edges for a +type. This keeps the per-node-type all-reduce order consistent across ranks. + +Degree tensors are stored as int32 to match the PPR sampler's needs while +keeping memory usage low. This assumes individual node degrees stay below the +int32 maximum, which is far above expected node degrees. """ from collections import Counter From 3b3497dbabdd78a6b0ed5c7724fd34c5e6b67007 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Fri, 29 May 2026 20:10:21 +0000 Subject: [PATCH 24/32] Address PPR degree review comments --- gigl/distributed/dist_ppr_sampler.py | 10 +- .../shared_dist_sampling_producer.py | 4 +- gigl/distributed/utils/degree.py | 7 +- gigl/distributed/utils/neighborloader.py | 1 + tests/unit/distributed/utils/degree_test.py | 107 ++++-------------- 5 files changed, 35 insertions(+), 94 deletions(-) diff --git a/gigl/distributed/dist_ppr_sampler.py b/gigl/distributed/dist_ppr_sampler.py index 644a1ee76..cb007c291 100644 --- a/gigl/distributed/dist_ppr_sampler.py +++ b/gigl/distributed/dist_ppr_sampler.py @@ -130,9 +130,11 @@ def __init__( ] self._is_homogeneous = True - # Normalize the public homogeneous/heterogeneous degree-tensor shape to + # Convert the public homogeneous/heterogeneous degree-tensor shape to # the node-type keyed form used internally by PPR. - self._node_type_to_total_degree = self._normalize_degree_tensors(degree_tensors) + self._node_type_to_total_degree = self._convert_degree_tensors_to_dict( + degree_tensors + ) # Build integer ID mappings for the C++ forward-push kernel. String # NodeType / EdgeType keys are only used at the Python boundary @@ -182,11 +184,11 @@ def __init__( for nt in all_node_types ] - def _normalize_degree_tensors( + def _convert_degree_tensors_to_dict( self, degree_tensors: Union[torch.Tensor, dict[NodeType, torch.Tensor]], ) -> dict[NodeType, torch.Tensor]: - """Normalize degree tensors to the node-type keyed shape PPR uses.""" + """Convert degree tensors to the node-type keyed shape PPR uses.""" if isinstance(degree_tensors, torch.Tensor): if not self._is_homogeneous: raise ValueError( diff --git a/gigl/distributed/graph_store/shared_dist_sampling_producer.py b/gigl/distributed/graph_store/shared_dist_sampling_producer.py index e8d679c7b..6faf56721 100644 --- a/gigl/distributed/graph_store/shared_dist_sampling_producer.py +++ b/gigl/distributed/graph_store/shared_dist_sampling_producer.py @@ -873,7 +873,9 @@ def __init__( set ) # Move degree tensors to shared memory so all spawned workers map the - # same allocation instead of each pickling a private copy. + # same allocation instead of each pickling a private copy. DistDataset + # also shares cached degree tensors through IPC, but graph-store PPR can + # compute them lazily when this backend is created. self._degree_tensors: Optional[ Union[torch.Tensor, dict[NodeType, torch.Tensor]] ] = degree_tensors diff --git a/gigl/distributed/utils/degree.py b/gigl/distributed/utils/degree.py index 6306b09de..f4aa7ce32 100644 --- a/gigl/distributed/utils/degree.py +++ b/gigl/distributed/utils/degree.py @@ -27,9 +27,10 @@ non-label edge types on every rank, even when a rank has no local edges for a type. This keeps the per-node-type all-reduce order consistent across ranks. -Degree tensors are stored as int32 to match the PPR sampler's needs while -keeping memory usage low. This assumes individual node degrees stay below the -int32 maximum, which is far above expected node degrees. +Degree tensors are stored as int32 because the PPR C++ sampler expects standard +int32/int64 integer tensors; int32 keeps memory usage lower than int64. This +assumes individual node degrees stay below the int32 maximum, which is far +above expected node degrees. """ from collections import Counter diff --git a/gigl/distributed/utils/neighborloader.py b/gigl/distributed/utils/neighborloader.py index 570fca93b..b24de6733 100644 --- a/gigl/distributed/utils/neighborloader.py +++ b/gigl/distributed/utils/neighborloader.py @@ -367,6 +367,7 @@ def attach_ppr_outputs( edge_type = next(iter(ppr_edge_indices)) data.edge_index = ppr_edge_indices[edge_type] data.edge_attr = ppr_weights[edge_type] + # Homogeneous Data has no per-edge-type stores; the PPR edges are attached. return for edge_type, edge_index in ppr_edge_indices.items(): diff --git a/tests/unit/distributed/utils/degree_test.py b/tests/unit/distributed/utils/degree_test.py index 85eeaa0e2..809590093 100644 --- a/tests/unit/distributed/utils/degree_test.py +++ b/tests/unit/distributed/utils/degree_test.py @@ -1,5 +1,3 @@ -from typing import Literal - import torch import torch.distributed as dist import torch.multiprocessing as mp @@ -15,6 +13,10 @@ from tests.test_assets.distributed.test_dataset import ( DEFAULT_HETEROGENEOUS_EDGE_INDICES, DEFAULT_HOMOGENEOUS_EDGE_INDEX, + STORY, + STORY_TO_USER, + USER, + USER_TO_STORY, create_heterogeneous_dataset, create_homogeneous_dataset, ) @@ -37,48 +39,6 @@ def _compute_expected_degrees_from_edge_index( return degrees -def _get_anchor_node_type( - edge_type: EdgeType, edge_dir: Literal["in", "out"] -) -> NodeType: - """Return the node type whose CSR rows define traversable degrees.""" - return edge_type.dst_node_type if edge_dir == "in" else edge_type.src_node_type - - -def _compute_expected_total_degrees_by_node_type( - edge_indices: dict[EdgeType, torch.Tensor], - edge_dir: Literal["in", "out"], -) -> dict[NodeType, torch.Tensor]: - """Compute total degrees keyed by anchor node type.""" - node_axis = 1 if edge_dir == "in" else 0 - expected: dict[NodeType, torch.Tensor] = {} - for edge_type, edge_index in edge_indices.items(): - anchor_node_type = _get_anchor_node_type(edge_type, edge_dir) - num_nodes = ( - int(edge_index[node_axis].max().item() + 1) - if edge_index.shape[1] > 0 - else 0 - ) - degrees = _compute_expected_degrees_from_edge_index( - edge_index=edge_index, - num_nodes=num_nodes, - node_axis=node_axis, - ) - - if anchor_node_type not in expected: - expected[anchor_node_type] = degrees - continue - - max_len = max(expected[anchor_node_type].numel(), degrees.numel()) - summed_degrees = torch.zeros(max_len, dtype=torch.int64) - summed_degrees[: expected[anchor_node_type].numel()] += expected[ - anchor_node_type - ].to(torch.int64) - summed_degrees[: degrees.numel()] += degrees.to(torch.int64) - expected[anchor_node_type] = summed_degrees.to(torch.int32) - - return expected - - class TestDegreeComputation(TestCase): """Tests for degree computation with torch.distributed initialized. @@ -120,10 +80,10 @@ def test_heterogeneous_graph(self): result = compute_and_broadcast_degree_tensor(dataset.graph, dataset.edge_dir) assert not isinstance(result, torch.Tensor) - expected = _compute_expected_total_degrees_by_node_type( - edge_indices=edge_indices, - edge_dir=dataset.edge_dir, - ) + expected = { + USER: torch.ones(5, dtype=torch.int32), + STORY: torch.ones(5, dtype=torch.int32), + } self.assertEqual(set(result.keys()), set(expected.keys())) for node_type, expected_degrees in expected.items(): @@ -142,47 +102,22 @@ def test_heterogeneous_graph_with_missing_topology(self): assert dataset.graph is not None assert isinstance(dataset.graph, dict) - # Get edge types from the dataset - edge_types = list(dataset.graph.keys()) - - edge_type_with_topo = edge_types[0] - edge_type_without_topo = edge_types[1] - - # Save the original topology for computing expected degrees - original_graph = dataset.graph[edge_type_with_topo] + original_graph = dataset.graph[USER_TO_STORY] assert original_graph.topo is not None - expected_degrees = _compute_expected_total_degrees_by_node_type( - edge_indices={edge_type_with_topo: edge_indices[edge_type_with_topo]}, - edge_dir=dataset.edge_dir, - ) # Manually set one graph's topology to None to test the edge case - dataset.graph[edge_type_without_topo].topo = None + dataset.graph[STORY_TO_USER].topo = None result = compute_and_broadcast_degree_tensor(dataset.graph, dataset.edge_dir) assert not isinstance(result, torch.Tensor) - expected_node_types = { - _get_anchor_node_type(edge_type, dataset.edge_dir) - for edge_type in edge_types - } - self.assertEqual(set(result.keys()), expected_node_types) + self.assertEqual(set(result.keys()), {USER, STORY}) # Edge type with topology should have computed degrees - node_type_with_topo = _get_anchor_node_type( - edge_type=edge_type_with_topo, - edge_dir=dataset.edge_dir, - ) - self.assert_tensor_equality( - result[node_type_with_topo], expected_degrees[node_type_with_topo] - ) + self.assert_tensor_equality(result[USER], torch.ones(5, dtype=torch.int32)) # Edge type without topology should have empty tensor - node_type_without_topo = _get_anchor_node_type( - edge_type=edge_type_without_topo, - edge_dir=dataset.edge_dir, - ) - self.assertEqual(result[node_type_without_topo].numel(), 0) + self.assertEqual(result[STORY].numel(), 0) def _run_local_world_size_correction_homogeneous( @@ -264,10 +199,10 @@ def test_local_world_size_correction_heterogeneous(self): """Test over-counting correction for heterogeneous graphs with 2 processes.""" edge_indices = DEFAULT_HETEROGENEOUS_EDGE_INDICES - expected_degrees = _compute_expected_total_degrees_by_node_type( - edge_indices=edge_indices, - edge_dir="out", - ) + expected_degrees = { + USER: torch.ones(5, dtype=torch.int32), + STORY: torch.ones(5, dtype=torch.int32), + } init_method = get_process_group_init_method() mp.spawn( @@ -320,10 +255,10 @@ def test_degree_tensor_heterogeneous(self): result = dataset.degree_tensor assert not isinstance(result, torch.Tensor) - expected = _compute_expected_total_degrees_by_node_type( - edge_indices=edge_indices, - edge_dir=dataset.edge_dir, - ) + expected = { + USER: torch.ones(5, dtype=torch.int32), + STORY: torch.ones(5, dtype=torch.int32), + } self.assertEqual(set(result.keys()), set(expected.keys())) for node_type, expected_degrees in expected.items(): From 5ac1c6342448a3fd4c564a097453d95a92ec5f85 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Fri, 29 May 2026 20:28:21 +0000 Subject: [PATCH 25/32] Address PPR degree memory review comments --- gigl/distributed/dist_dataset.py | 4 ++++ .../shared_dist_sampling_producer.py | 6 ----- gigl/distributed/utils/degree.py | 24 ++++++++++++------- tests/unit/distributed/utils/degree_test.py | 13 ++++++++++ 4 files changed, 32 insertions(+), 15 deletions(-) diff --git a/gigl/distributed/dist_dataset.py b/gigl/distributed/dist_dataset.py index 181c2c7d9..1ba437e73 100644 --- a/gigl/distributed/dist_dataset.py +++ b/gigl/distributed/dist_dataset.py @@ -330,6 +330,9 @@ def degree_tensor( The result is cached for subsequent accesses. + The cached degree tensor is moved to shared memory before being returned, + so spawned workers can share a single allocation. + Returns: Union[torch.Tensor, dict[NodeType, torch.Tensor]]: Degree tensor for homogeneous graphs, or total degree tensors keyed by node type @@ -346,6 +349,7 @@ def degree_tensor( self._degree_tensor = compute_and_broadcast_degree_tensor( self.graph, self._edge_dir ) + share_memory(entity=self._degree_tensor) return self._degree_tensor @property diff --git a/gigl/distributed/graph_store/shared_dist_sampling_producer.py b/gigl/distributed/graph_store/shared_dist_sampling_producer.py index 6faf56721..4823650cf 100644 --- a/gigl/distributed/graph_store/shared_dist_sampling_producer.py +++ b/gigl/distributed/graph_store/shared_dist_sampling_producer.py @@ -103,7 +103,6 @@ SamplerRuntime, create_dist_sampler, ) -from gigl.utils.share_memory import share_memory logger = Logger() @@ -872,10 +871,6 @@ def __init__( self._completed_workers: defaultdict[tuple[int, int], set[int]] = defaultdict( set ) - # Move degree tensors to shared memory so all spawned workers map the - # same allocation instead of each pickling a private copy. DistDataset - # also shares cached degree tensors through IPC, but graph-store PPR can - # compute them lazily when this backend is created. self._degree_tensors: Optional[ Union[torch.Tensor, dict[NodeType, torch.Tensor]] ] = degree_tensors @@ -890,7 +885,6 @@ def __init__( f"Pre-computed degree tensor for PPR sampling with " f"{degree_tensors.size(0)} nodes." ) - share_memory(self._degree_tensors) def init_backend(self) -> None: """Initialize worker processes once for this backend. diff --git a/gigl/distributed/utils/degree.py b/gigl/distributed/utils/degree.py index f4aa7ce32..102c3c712 100644 --- a/gigl/distributed/utils/degree.py +++ b/gigl/distributed/utils/degree.py @@ -28,13 +28,12 @@ type. This keeps the per-node-type all-reduce order consistent across ranks. Degree tensors are stored as int32 because the PPR C++ sampler expects standard -int32/int64 integer tensors; int32 keeps memory usage lower than int64. This -assumes individual node degrees stay below the int32 maximum, which is far -above expected node degrees. +int32/int64 integer tensors; int32 keeps memory usage lower than int64. Values +above the int32 maximum are clamped before casting to avoid wraparound. """ from collections import Counter -from typing import Union, overload +from typing import Final, Union, overload import torch from graphlearn_torch.data import Graph @@ -48,6 +47,8 @@ logger = Logger() +_INT32_MAX: Final[int] = torch.iinfo(torch.int32).max + def compute_and_broadcast_degree_tensor( graph: Union[Graph, dict[EdgeType, Graph]], @@ -124,7 +125,7 @@ def compute_and_broadcast_degree_tensor( max_len = max(len(existing), len(degrees)) summed = _pad_to_size(existing, max_len).to(torch.int64) summed[: len(degrees)] += degrees.to(torch.int64) - local_dict[anchor_type] = summed.to(torch.int32) + local_dict[anchor_type] = _clamp_to_int32(summed) else: local_dict[anchor_type] = degrees @@ -156,9 +157,14 @@ def _pad_to_size(tensor: torch.Tensor, target_size: int) -> torch.Tensor: return torch.cat([tensor, padding]) +def _clamp_to_int32(tensor: torch.Tensor) -> torch.Tensor: + """Clamp degree values to int32 range before converting dtype.""" + return tensor.clamp(max=_INT32_MAX).to(torch.int32) + + def _compute_degrees_from_indptr(indptr: torch.Tensor) -> torch.Tensor: """Compute degrees from CSR row pointers: degree[i] = indptr[i+1] - indptr[i].""" - return (indptr[1:] - indptr[:-1]).to(torch.int32) + return _clamp_to_int32(indptr[1:] - indptr[:-1]) def _get_degree_reduce_context() -> tuple[int, torch.device]: @@ -194,9 +200,9 @@ def _all_reduce_single_degree_tensor( padded = _pad_to_size(tensor, max_size).to(torch.int64).to(device) torch.distributed.all_reduce(padded, op=torch.distributed.ReduceOp.SUM) - # Correct for over-counting and move back to CPU. We keep int32 so high-degree - # nodes do not saturate at int16. - return (padded // local_world_size).to(torch.int32).cpu() + # Correct for over-counting and move back to CPU. Clamp before casting so + # high-degree nodes saturate instead of wrapping. + return _clamp_to_int32(padded // local_world_size).cpu() @overload diff --git a/tests/unit/distributed/utils/degree_test.py b/tests/unit/distributed/utils/degree_test.py index 809590093..465415202 100644 --- a/tests/unit/distributed/utils/degree_test.py +++ b/tests/unit/distributed/utils/degree_test.py @@ -5,6 +5,7 @@ from parameterized import param, parameterized from gigl.distributed.utils.degree import ( + _clamp_to_int32, _compute_degrees_from_indptr, _pad_to_size, compute_and_broadcast_degree_tensor, @@ -236,6 +237,7 @@ def test_degree_tensor_homogeneous(self): assert isinstance(result, torch.Tensor) expected = _compute_expected_degrees_from_edge_index(edge_index, num_nodes) + self.assertTrue(result.is_shared()) self.assert_tensor_equality(result, expected) def test_degree_tensor_caches_result(self): @@ -262,6 +264,7 @@ def test_degree_tensor_heterogeneous(self): self.assertEqual(set(result.keys()), set(expected.keys())) for node_type, expected_degrees in expected.items(): + self.assertTrue(result[node_type].is_shared()) self.assert_tensor_equality(result[node_type], expected_degrees) @@ -302,6 +305,16 @@ def test_compute_degrees_from_indptr(self): result = _compute_degrees_from_indptr(indptr) self.assert_tensor_equality(result, expected) + def test_clamp_to_int32(self): + """Test that large degree values clamp before conversion.""" + int32_max = torch.iinfo(torch.int32).max + tensor = torch.tensor([0, int32_max, int32_max + 1], dtype=torch.int64) + expected = torch.tensor([0, int32_max, int32_max], dtype=torch.int32) + + result = _clamp_to_int32(tensor) + + self.assert_tensor_equality(result, expected) + def test_compute_degrees_from_indptr_all_zeros(self): """Test _compute_degrees_from_indptr with all-zero indptr (no edges).""" # All-zero indptr means no outgoing edges for any node From 2f35f22983baf28aa2c13c0213604f6feb6b6b1e Mon Sep 17 00:00:00 2001 From: mkolodner Date: Fri, 29 May 2026 20:46:38 +0000 Subject: [PATCH 26/32] Configure graph-store PPR sampler options inline --- .../e2e_hom_cora_sup_gs_ppr_task_config.yaml | 28 +++++++------ .../graph_store/homogeneous_inference.py | 10 +++-- .../graph_store/homogeneous_training.py | 10 +++-- gigl/utils/sampling.py | 41 ------------------- 4 files changed, 30 insertions(+), 59 deletions(-) diff --git a/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml b/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml index 46c508819..ad1ab8a4a 100644 --- a/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml +++ b/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml @@ -17,12 +17,14 @@ trainerConfig: trainerArgs: log_every_n_batch: "1" num_neighbors: "[10, 10]" - sampler_type: "ppr" - ppr_alpha: "0.5" - ppr_eps: "0.0001" - ppr_max_nodes: "20" - ppr_neighbors_per_hop: "100" - ppr_max_fetch_iterations: "2" + ppr_sampler_options: >- + { + "alpha": 0.5, + "eps": 0.0001, + "max_ppr_nodes": 20, + "num_neighbors_per_hop": 100, + "max_fetch_iterations": 2 + } sampling_workers_per_process: "2" main_batch_size: "8" random_batch_size: "8" @@ -47,12 +49,14 @@ inferencerConfig: inferencerArgs: log_every_n_batch: "1" num_neighbors: "[10, 10]" - sampler_type: "ppr" - ppr_alpha: "0.5" - ppr_eps: "0.0001" - ppr_max_nodes: "20" - ppr_neighbors_per_hop: "100" - ppr_max_fetch_iterations: "2" + ppr_sampler_options: >- + { + "alpha": 0.5, + "eps": 0.0001, + "max_ppr_nodes": 20, + "num_neighbors_per_hop": 100, + "max_fetch_iterations": 2 + } sampling_workers_per_inference_process: "2" inferenceBatchSize: 256 command: python -m examples.link_prediction.graph_store.homogeneous_inference diff --git a/examples/link_prediction/graph_store/homogeneous_inference.py b/examples/link_prediction/graph_store/homogeneous_inference.py index 5faa84b72..16333c75f 100644 --- a/examples/link_prediction/graph_store/homogeneous_inference.py +++ b/examples/link_prediction/graph_store/homogeneous_inference.py @@ -83,6 +83,7 @@ import argparse import gc +import json import os import sys import time @@ -101,7 +102,7 @@ from gigl.common.utils.gcs import GcsUtils from gigl.distributed.graph_store.compute import init_compute_process from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset -from gigl.distributed.sampler_options import SamplerOptions +from gigl.distributed.sampler_options import PPRSamplerOptions, SamplerOptions from gigl.distributed.utils import get_graph_store_info from gigl.env.distributed import GraphStoreInfo from gigl.nn import LinkPredictionGNN @@ -111,7 +112,7 @@ from gigl.src.common.utils.bq import BqUtils from gigl.src.common.utils.model import load_state_dict_from_uri from gigl.src.inference.lib.assets import InferenceAssets -from gigl.utils.sampling import parse_fanout, parse_sampler_options +from gigl.utils.sampling import parse_fanout logger = Logger() @@ -490,7 +491,10 @@ def _run_example_inference( # Parses the fanout as a string. For the homogeneous case, the fanouts should be specified # as a string of a list of integers, such as "[10, 10]". num_neighbors = parse_fanout(inferencer_args.get("num_neighbors", "[10, 10]")) - sampler_options = parse_sampler_options(inferencer_args) + sampler_options: Optional[SamplerOptions] = None + sampler_options_args = inferencer_args.get("ppr_sampler_options") + if sampler_options_args is not None and sampler_options_args.strip(): + sampler_options = PPRSamplerOptions(**json.loads(sampler_options_args)) # While the ideal value for `sampling_workers_per_inference_process` has been identified to be # between `2` and `4`, this may need some tuning depending on the pipeline. We default this diff --git a/examples/link_prediction/graph_store/homogeneous_training.py b/examples/link_prediction/graph_store/homogeneous_training.py index c7ae356cc..cdfd6d93b 100644 --- a/examples/link_prediction/graph_store/homogeneous_training.py +++ b/examples/link_prediction/graph_store/homogeneous_training.py @@ -119,6 +119,7 @@ import argparse import gc +import json import os import statistics import sys @@ -143,7 +144,7 @@ shutdown_compute_process, ) from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset -from gigl.distributed.sampler_options import SamplerOptions +from gigl.distributed.sampler_options import PPRSamplerOptions, SamplerOptions from gigl.distributed.utils import get_available_device, get_graph_store_info from gigl.env.distributed import GraphStoreInfo from gigl.nn import LinkPredictionGNN, RetrievalLoss @@ -159,7 +160,7 @@ from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper from gigl.src.common.utils.model import load_state_dict_from_uri, save_state_dict from gigl.utils.iterator import InfiniteIterator -from gigl.utils.sampling import parse_fanout, parse_sampler_options +from gigl.utils.sampling import parse_fanout logger = Logger() @@ -867,7 +868,10 @@ def _run_example_training( fanout = trainer_args.get("num_neighbors", "[10, 10]") num_neighbors = parse_fanout(fanout) - sampler_options = parse_sampler_options(trainer_args) + sampler_options: Optional[SamplerOptions] = None + sampler_options_args = trainer_args.get("ppr_sampler_options") + if sampler_options_args is not None and sampler_options_args.strip(): + sampler_options = PPRSamplerOptions(**json.loads(sampler_options_args)) sampling_workers_per_process: int = int( trainer_args.get("sampling_workers_per_process", "4") diff --git a/gigl/utils/sampling.py b/gigl/utils/sampling.py index e2c6996e5..5d0ed6a44 100644 --- a/gigl/utils/sampling.py +++ b/gigl/utils/sampling.py @@ -1,12 +1,10 @@ import ast -from collections.abc import Mapping from dataclasses import dataclass from typing import Any, Optional, Union import torch from gigl.common.logger import Logger -from gigl.distributed.sampler_options import PPRSamplerOptions, SamplerOptions from gigl.src.common.types.graph_data import EdgeType, NodeType logger = Logger() @@ -90,45 +88,6 @@ def parse_fanout(fanout_str: str) -> Union[list[int], dict[EdgeType, list[int]]] ) -def _parse_optional_int(value: Optional[str]) -> Optional[int]: - if value is None: - return None - normalized = value.strip().lower() - if normalized in {"", "none", "null"}: - return None - return int(value) - - -def parse_sampler_options(args: Mapping[str, str]) -> Optional[SamplerOptions]: - sampler_type = args.get("sampler_type", "khop").strip().lower().replace("-", "_") - if sampler_type == "": - sampler_type = "khop" - - if sampler_type in {"khop", "k_hop", "neighbor", "neighbor_sampler"}: - return None - - if sampler_type != "ppr": - raise ValueError( - f"Unsupported sampler_type={sampler_type}. Expected one of: khop, ppr." - ) - - max_ppr_nodes = args.get("ppr_max_nodes") - if max_ppr_nodes is None: - max_ppr_nodes = args.get("ppr_max_ppr_nodes", "50") - - num_neighbors_per_hop = args.get("ppr_neighbors_per_hop") - if num_neighbors_per_hop is None: - num_neighbors_per_hop = args.get("ppr_num_neighbors_per_hop", "1000") - - return PPRSamplerOptions( - alpha=float(args.get("ppr_alpha", "0.5")), - eps=float(args.get("ppr_eps", "0.0001")), - max_ppr_nodes=int(max_ppr_nodes), - num_neighbors_per_hop=int(num_neighbors_per_hop), - max_fetch_iterations=_parse_optional_int(args.get("ppr_max_fetch_iterations")), - ) - - @dataclass(frozen=True) class ABLPInputNodes: """Represents ABLP (Anchor Based Link Prediction) input for a single storage server. From aa42d7abd4f6f16816d812a005d359e091706684 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Fri, 29 May 2026 21:26:03 +0000 Subject: [PATCH 27/32] Comments --- gigl/distributed/utils/neighborloader.py | 17 ++++++++--------- .../distributed/utils/neighborloader_test.py | 5 +++++ 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/gigl/distributed/utils/neighborloader.py b/gigl/distributed/utils/neighborloader.py index b24de6733..00d08c79d 100644 --- a/gigl/distributed/utils/neighborloader.py +++ b/gigl/distributed/utils/neighborloader.py @@ -338,12 +338,11 @@ def attach_ppr_outputs( ppr_edge_indices: dict[EdgeType, torch.Tensor], ppr_weights: dict[EdgeType, torch.Tensor], ) -> None: - """Attach PPR edge indices and weights onto a HeteroData object. + """Attach PPR edge indices and weights onto a Data/HeteroData object. For each PPR edge type, sets ``data[edge_type].edge_index`` and ``data[edge_type].edge_attr`` in-place. Called from the loader's - ``_collate_fn`` only when a PPR sampler is active; the function is a - no-op if both dicts are empty. + ``_collate_fn`` only when a PPR sampler is active. Args: data: The Data or HeteroData object to attach outputs to. @@ -352,21 +351,21 @@ def attach_ppr_outputs( Raises: AssertionError: If ``ppr_edge_indices`` and ``ppr_weights`` have different edge-type keys. + ValueError: If homogeneous ``Data`` does not have exactly one PPR edge type. """ assert ppr_edge_indices.keys() == ppr_weights.keys(), ( f"PPR edge index and weight edge types must match, " f"got {set(ppr_edge_indices.keys())} vs {set(ppr_weights.keys())}" ) if isinstance(data, Data): - if len(ppr_edge_indices) > 1: + if len(ppr_edge_indices) != 1: raise ValueError( - "Expected at most one PPR edge type for homogeneous Data output, " + "Expected exactly one PPR edge type for homogeneous Data output, " f"got {set(ppr_edge_indices.keys())}" ) - if ppr_edge_indices: - edge_type = next(iter(ppr_edge_indices)) - data.edge_index = ppr_edge_indices[edge_type] - data.edge_attr = ppr_weights[edge_type] + edge_type = next(iter(ppr_edge_indices)) + data.edge_index = ppr_edge_indices[edge_type] + data.edge_attr = ppr_weights[edge_type] # Homogeneous Data has no per-edge-type stores; the PPR edges are attached. return diff --git a/tests/unit/distributed/utils/neighborloader_test.py b/tests/unit/distributed/utils/neighborloader_test.py index 8de5a2f79..a9b9c023f 100644 --- a/tests/unit/distributed/utils/neighborloader_test.py +++ b/tests/unit/distributed/utils/neighborloader_test.py @@ -11,6 +11,7 @@ POSITIVE_LABEL_METADATA_KEY, ) from gigl.distributed.utils.neighborloader import ( + attach_ppr_outputs, extract_edge_type_metadata, extract_metadata, labeled_to_homogeneous, @@ -250,6 +251,10 @@ def test_strip_non_ppr_edge_types(self): self.assertIn(_PPR_U2I, result.edge_types) self.assertIn(_PPR_U2U, result.edge_types) + def test_attach_ppr_outputs_requires_one_homogeneous_edge_type(self): + with self.assertRaises(ValueError): + attach_ppr_outputs(Data(), {}, {}) + @parameterized.expand( [ param( From 188525fb745d310ba3975f9480e7b658b7fcc069 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Fri, 29 May 2026 21:44:34 +0000 Subject: [PATCH 28/32] Clarify graph-store PPR sampler args --- .../configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml | 8 ++++++++ .../link_prediction/graph_store/homogeneous_inference.py | 4 ++-- .../link_prediction/graph_store/homogeneous_training.py | 4 ++-- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml b/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml index ad1ab8a4a..1e440dc7f 100644 --- a/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml +++ b/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml @@ -17,6 +17,10 @@ trainerConfig: trainerArgs: log_every_n_batch: "1" num_neighbors: "[10, 10]" + # Parsed in the graph-store training entrypoint and passed directly as + # kwargs to PPRSamplerOptions in gigl/distributed/sampler_options.py. + # Presence of ppr_sampler_options selects PPR; otherwise this example uses + # k-hop sampling configured by num_neighbors. ppr_sampler_options: >- { "alpha": 0.5, @@ -49,6 +53,10 @@ inferencerConfig: inferencerArgs: log_every_n_batch: "1" num_neighbors: "[10, 10]" + # Parsed in the graph-store inference entrypoint and passed directly as + # kwargs to PPRSamplerOptions in gigl/distributed/sampler_options.py. + # Presence of ppr_sampler_options selects PPR; otherwise this example uses + # k-hop sampling configured by num_neighbors. ppr_sampler_options: >- { "alpha": 0.5, diff --git a/examples/link_prediction/graph_store/homogeneous_inference.py b/examples/link_prediction/graph_store/homogeneous_inference.py index 16333c75f..26dbad8e9 100644 --- a/examples/link_prediction/graph_store/homogeneous_inference.py +++ b/examples/link_prediction/graph_store/homogeneous_inference.py @@ -82,8 +82,8 @@ """ import argparse +import ast import gc -import json import os import sys import time @@ -494,7 +494,7 @@ def _run_example_inference( sampler_options: Optional[SamplerOptions] = None sampler_options_args = inferencer_args.get("ppr_sampler_options") if sampler_options_args is not None and sampler_options_args.strip(): - sampler_options = PPRSamplerOptions(**json.loads(sampler_options_args)) + sampler_options = PPRSamplerOptions(**ast.literal_eval(sampler_options_args)) # While the ideal value for `sampling_workers_per_inference_process` has been identified to be # between `2` and `4`, this may need some tuning depending on the pipeline. We default this diff --git a/examples/link_prediction/graph_store/homogeneous_training.py b/examples/link_prediction/graph_store/homogeneous_training.py index cdfd6d93b..8f601399f 100644 --- a/examples/link_prediction/graph_store/homogeneous_training.py +++ b/examples/link_prediction/graph_store/homogeneous_training.py @@ -118,8 +118,8 @@ """ import argparse +import ast import gc -import json import os import statistics import sys @@ -871,7 +871,7 @@ def _run_example_training( sampler_options: Optional[SamplerOptions] = None sampler_options_args = trainer_args.get("ppr_sampler_options") if sampler_options_args is not None and sampler_options_args.strip(): - sampler_options = PPRSamplerOptions(**json.loads(sampler_options_args)) + sampler_options = PPRSamplerOptions(**ast.literal_eval(sampler_options_args)) sampling_workers_per_process: int = int( trainer_args.get("sampling_workers_per_process", "4") From 2641834a5c5860e4bfaf61df61692129f15ca705 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Fri, 29 May 2026 21:50:20 +0000 Subject: [PATCH 29/32] Document PPR degree tensor dtype rationale --- gigl/distributed/utils/degree.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/gigl/distributed/utils/degree.py b/gigl/distributed/utils/degree.py index 102c3c712..f88d3c0c6 100644 --- a/gigl/distributed/utils/degree.py +++ b/gigl/distributed/utils/degree.py @@ -27,9 +27,12 @@ non-label edge types on every rank, even when a rank has no local edges for a type. This keeps the per-node-type all-reduce order consistent across ranks. -Degree tensors are stored as int32 because the PPR C++ sampler expects standard -int32/int64 integer tensors; int32 keeps memory usage lower than int64. Values -above the int32 maximum are clamped before casting to avoid wraparound. +Degree tensors are stored as int32 to stay aligned with the PPR C++ sampler's +total-degree tensor requirement while keeping memory lower than int64. We avoid +int16 because it has caused compatibility issues in this path: during the C++ +PPR sampler migration, ``torch.distributed.all_reduce`` on an int16 tensor +produced ``RuntimeError: Invalid scalar type``. Values above the int32 maximum +are clamped before casting to avoid wraparound. """ from collections import Counter From 1ff863532e5ede5a9b1af46ffc766b6f4c6e0f92 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Fri, 29 May 2026 21:58:58 +0000 Subject: [PATCH 30/32] Address remaining comments --- gigl/distributed/base_dist_loader.py | 2 ++ gigl/distributed/dist_dataset.py | 4 ---- gigl/distributed/dist_ppr_sampler.py | 6 +++--- gigl/distributed/graph_store/dist_server.py | 9 ++++++--- .../graph_store/shared_dist_sampling_producer.py | 3 ++- tests/unit/distributed/utils/degree_test.py | 2 -- 6 files changed, 13 insertions(+), 13 deletions(-) diff --git a/gigl/distributed/base_dist_loader.py b/gigl/distributed/base_dist_loader.py index c298aa0d2..35b33c0c9 100644 --- a/gigl/distributed/base_dist_loader.py +++ b/gigl/distributed/base_dist_loader.py @@ -54,6 +54,7 @@ patch_fanout_for_sampling, ) from gigl.types.graph import DEFAULT_HOMOGENEOUS_NODE_TYPE +from gigl.utils.share_memory import share_memory logger = Logger() @@ -466,6 +467,7 @@ def create_mp_producer( channel = BaseDistLoader.create_colocated_channel(worker_options) if isinstance(sampler_options, PPRSamplerOptions): degree_tensors = dataset.degree_tensor + share_memory(degree_tensors) if isinstance(degree_tensors, dict): logger.info( f"Pre-computed degree tensors for PPR sampling across " diff --git a/gigl/distributed/dist_dataset.py b/gigl/distributed/dist_dataset.py index 1ba437e73..181c2c7d9 100644 --- a/gigl/distributed/dist_dataset.py +++ b/gigl/distributed/dist_dataset.py @@ -330,9 +330,6 @@ def degree_tensor( The result is cached for subsequent accesses. - The cached degree tensor is moved to shared memory before being returned, - so spawned workers can share a single allocation. - Returns: Union[torch.Tensor, dict[NodeType, torch.Tensor]]: Degree tensor for homogeneous graphs, or total degree tensors keyed by node type @@ -349,7 +346,6 @@ def degree_tensor( self._degree_tensor = compute_and_broadcast_degree_tensor( self.graph, self._edge_dir ) - share_memory(entity=self._degree_tensor) return self._degree_tensor @property diff --git a/gigl/distributed/dist_ppr_sampler.py b/gigl/distributed/dist_ppr_sampler.py index cb007c291..f325e2032 100644 --- a/gigl/distributed/dist_ppr_sampler.py +++ b/gigl/distributed/dist_ppr_sampler.py @@ -76,9 +76,9 @@ class DistPPRNeighborSampler(BaseDistNeighborSampler): num_neighbors_per_hop: Maximum number of neighbors to fetch per hop. degree_tensors: Pre-computed total-degree tensors (int32). Homogeneous graphs use a single tensor; heterogeneous graphs use tensors keyed - by NodeType. Must be pre-computed by the caller through - ``DistDataset.degree_tensor`` so workers share a single allocation - rather than recomputing per-worker. + by NodeType. The colocated and graph-store loader paths retrieve + these through ``DistDataset.degree_tensor`` and move them to shared + memory before worker handoff. """ def __init__( diff --git a/gigl/distributed/graph_store/dist_server.py b/gigl/distributed/graph_store/dist_server.py index 3e100022e..3c8a87ee1 100644 --- a/gigl/distributed/graph_store/dist_server.py +++ b/gigl/distributed/graph_store/dist_server.py @@ -96,6 +96,7 @@ def compute_process(): from gigl.src.common.types.graph_data import EdgeType, NodeType from gigl.types.graph import FeatureInfo, select_label_edge_types from gigl.utils.data_splitters import get_labels_for_anchor_nodes +from gigl.utils.share_memory import share_memory SERVER_EXIT_STATUS_CHECK_INTERVAL = 5.0 r""" Interval (in seconds) to check exit status of server. @@ -603,6 +604,10 @@ def init_sampling_backend(self, opts: InitSamplingBackendRequest) -> int: else: backend_id = self._next_backend_id self._next_backend_id += 1 + degree_tensors = None + if isinstance(opts.sampler_options, PPRSamplerOptions): + degree_tensors = self.dataset.degree_tensor + share_memory(degree_tensors) backend_state = SamplingBackendState( backend_id=backend_id, backend_key=opts.backend_key, @@ -612,9 +617,7 @@ def init_sampling_backend(self, opts: InitSamplingBackendRequest) -> int: sampling_config=opts.sampling_config, sampler_options=opts.sampler_options, # We only need degree tensor for PPR sampling - degree_tensors=self.dataset.degree_tensor - if isinstance(opts.sampler_options, PPRSamplerOptions) - else None, + degree_tensors=degree_tensors, ), ) self._backend_id_by_backend_key[opts.backend_key] = backend_id diff --git a/gigl/distributed/graph_store/shared_dist_sampling_producer.py b/gigl/distributed/graph_store/shared_dist_sampling_producer.py index 4823650cf..7b1c2831c 100644 --- a/gigl/distributed/graph_store/shared_dist_sampling_producer.py +++ b/gigl/distributed/graph_store/shared_dist_sampling_producer.py @@ -364,7 +364,8 @@ def _shared_sampling_worker_loop( for PPR-based sampling). degree_tensors: Pre-computed degree tensors for PPR sampling, or ``None`` for non-PPR samplers. Materialized once in the parent via - ``DistDataset.degree_tensor`` and shared across workers. + ``DistDataset.degree_tensor`` and moved to shared memory before + backend construction. Algorithm: 1. Initialize RPC, sampler infrastructure, and signal the parent via barrier. diff --git a/tests/unit/distributed/utils/degree_test.py b/tests/unit/distributed/utils/degree_test.py index 465415202..70ee02f18 100644 --- a/tests/unit/distributed/utils/degree_test.py +++ b/tests/unit/distributed/utils/degree_test.py @@ -237,7 +237,6 @@ def test_degree_tensor_homogeneous(self): assert isinstance(result, torch.Tensor) expected = _compute_expected_degrees_from_edge_index(edge_index, num_nodes) - self.assertTrue(result.is_shared()) self.assert_tensor_equality(result, expected) def test_degree_tensor_caches_result(self): @@ -264,7 +263,6 @@ def test_degree_tensor_heterogeneous(self): self.assertEqual(set(result.keys()), set(expected.keys())) for node_type, expected_degrees in expected.items(): - self.assertTrue(result[node_type].is_shared()) self.assert_tensor_equality(result[node_type], expected_degrees) From 554826005c4c155065a1a0be5c8fe008ada55556 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Mon, 1 Jun 2026 18:03:33 +0000 Subject: [PATCH 31/32] Fix --- gigl/distributed/dist_ppr_sampler.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/gigl/distributed/dist_ppr_sampler.py b/gigl/distributed/dist_ppr_sampler.py index f325e2032..c2b8d5ada 100644 --- a/gigl/distributed/dist_ppr_sampler.py +++ b/gigl/distributed/dist_ppr_sampler.py @@ -649,8 +649,14 @@ async def _sample_from_nodes( ppr_edge_index = torch.stack([rows, cols]) - metadata["edge_index"] = ppr_edge_index - metadata["edge_attr"] = homo_flat_weights + homo_ppr_edge_type = ( + DEFAULT_HOMOGENEOUS_NODE_TYPE, + "ppr", + DEFAULT_HOMOGENEOUS_NODE_TYPE, + ) + etype_str = repr(homo_ppr_edge_type) + metadata[f"{PPR_EDGE_INDEX_METADATA_KEY}{etype_str}"] = ppr_edge_index + metadata[f"{PPR_WEIGHT_METADATA_KEY}{etype_str}"] = homo_flat_weights sample_output = SamplerOutput( node=all_nodes, From a9df2856fdeae70d8d416b80f5c42fe836ea8e55 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Mon, 1 Jun 2026 18:18:32 +0000 Subject: [PATCH 32/32] Improve solution --- gigl/distributed/base_dist_loader.py | 50 +++++++++++++++++++ gigl/distributed/dist_ablp_neighborloader.py | 18 +------ gigl/distributed/dist_ppr_sampler.py | 10 +--- .../distributed/distributed_neighborloader.py | 19 +------ 4 files changed, 54 insertions(+), 43 deletions(-) diff --git a/gigl/distributed/base_dist_loader.py b/gigl/distributed/base_dist_loader.py index 35b33c0c9..09b4dcac0 100644 --- a/gigl/distributed/base_dist_loader.py +++ b/gigl/distributed/base_dist_loader.py @@ -39,6 +39,10 @@ from gigl.distributed.constants import DEFAULT_MASTER_INFERENCE_PORT from gigl.distributed.dist_context import DistributedContext from gigl.distributed.dist_dataset import DistDataset +from gigl.distributed.dist_ppr_sampler import ( + PPR_EDGE_INDEX_METADATA_KEY, + PPR_WEIGHT_METADATA_KEY, +) from gigl.distributed.dist_sampling_producer import DistSamplingProducer from gigl.distributed.graph_store.compute import async_request_server from gigl.distributed.graph_store.dist_server import DistServer @@ -51,7 +55,10 @@ from gigl.distributed.sampler_options import PPRSamplerOptions, SamplerOptions from gigl.distributed.utils.neighborloader import ( DatasetSchema, + attach_ppr_outputs, + extract_edge_type_metadata, patch_fanout_for_sampling, + strip_non_ppr_edge_types, ) from gigl.types.graph import DEFAULT_HOMOGENEOUS_NODE_TYPE from gigl.utils.share_memory import share_memory @@ -918,6 +925,49 @@ def shutdown(self) -> None: ) self._shutdowned = True + def _apply_ppr_outputs( + self, + data: Union[Data, HeteroData], + metadata: dict[str, torch.Tensor], + ) -> tuple[Union[Data, HeteroData], dict[str, torch.Tensor]]: + """Attach PPR edge outputs from metadata onto the data object. + + For pure homogeneous graphs the PPR sampler writes plain ``"edge_index"`` + and ``"edge_attr"`` keys. For heterogeneous graphs (including + labeled-homogeneous graphs that have been converted to ``Data`` by + ``labeled_to_homogeneous``) it writes prefixed edge-type keys that are + parsed by ``extract_edge_type_metadata``. + + A no-op when the active sampler is not ``PPRSamplerOptions``. + + Args: + data: The Data or HeteroData object to attach PPR outputs to. + metadata: Remaining metadata dict; consumed entries are removed. + + Returns: + Updated ``(data, metadata)`` tuple. + """ + if not isinstance(self._sampler_options, PPRSamplerOptions): + return data, metadata + + if not self._is_homogeneous_with_labeled_edge_type and isinstance(data, Data): + # Pure homogeneous PPR: sampler writes plain "edge_index"/"edge_attr". + data.edge_index = metadata.pop("edge_index") + data.edge_attr = metadata.pop("edge_attr") + else: + # Hetero PPR (including labeled-homogeneous): prefixed edge-type keys. + matched, metadata = extract_edge_type_metadata( + metadata=metadata, + prefixes=[PPR_EDGE_INDEX_METADATA_KEY, PPR_WEIGHT_METADATA_KEY], + ) + ppr_edge_indices = matched[PPR_EDGE_INDEX_METADATA_KEY] + ppr_weights = matched[PPR_WEIGHT_METADATA_KEY] + attach_ppr_outputs(data, ppr_edge_indices, ppr_weights) + if isinstance(data, HeteroData): + data = strip_non_ppr_edge_types(data, set(ppr_edge_indices.keys())) + + return data, metadata + def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]: """Override GLT's _collate_fn to optionally batch-transfer tensors with non_blocking=True. diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index fa572a5d5..50f42f5a9 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -16,10 +16,6 @@ from gigl.distributed.base_dist_loader import BaseDistLoader from gigl.distributed.dist_context import DistributedContext from gigl.distributed.dist_dataset import DistDataset -from gigl.distributed.dist_ppr_sampler import ( - PPR_EDGE_INDEX_METADATA_KEY, - PPR_WEIGHT_METADATA_KEY, -) from gigl.distributed.dist_sampling_producer import DistSamplingProducer from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset from gigl.distributed.sampler import ( @@ -28,21 +24,18 @@ ABLPNodeSamplerInput, ) from gigl.distributed.sampler_options import ( - PPRSamplerOptions, SamplerOptions, resolve_sampler_options, ) from gigl.distributed.utils.neighborloader import ( DatasetSchema, SamplingClusterSetup, - attach_ppr_outputs, extract_edge_type_metadata, extract_metadata, labeled_to_homogeneous, set_missing_features, shard_nodes_by_process, strip_label_edges, - strip_non_ppr_edge_types, ) from gigl.src.common.types.graph_data import ( NodeType, # TODO (mkolodner-sc): Change to use torch_geometric.typing @@ -882,16 +875,7 @@ def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]: data = self._set_labels(data, positive_labels, negative_labels) - if isinstance(self._sampler_options, PPRSamplerOptions): - matched_ppr, metadata = extract_edge_type_metadata( - metadata=metadata, - prefixes=[PPR_EDGE_INDEX_METADATA_KEY, PPR_WEIGHT_METADATA_KEY], - ) - ppr_edge_indices = matched_ppr[PPR_EDGE_INDEX_METADATA_KEY] - ppr_weights = matched_ppr[PPR_WEIGHT_METADATA_KEY] - attach_ppr_outputs(data, ppr_edge_indices, ppr_weights) - if isinstance(data, HeteroData): - data = strip_non_ppr_edge_types(data, set(ppr_edge_indices.keys())) + data, metadata = self._apply_ppr_outputs(data, metadata) # Attach any remaining metadata (e.g. custom user-defined keys) directly onto the # data object so downstream code can access them via attribute lookup. diff --git a/gigl/distributed/dist_ppr_sampler.py b/gigl/distributed/dist_ppr_sampler.py index c2b8d5ada..f325e2032 100644 --- a/gigl/distributed/dist_ppr_sampler.py +++ b/gigl/distributed/dist_ppr_sampler.py @@ -649,14 +649,8 @@ async def _sample_from_nodes( ppr_edge_index = torch.stack([rows, cols]) - homo_ppr_edge_type = ( - DEFAULT_HOMOGENEOUS_NODE_TYPE, - "ppr", - DEFAULT_HOMOGENEOUS_NODE_TYPE, - ) - etype_str = repr(homo_ppr_edge_type) - metadata[f"{PPR_EDGE_INDEX_METADATA_KEY}{etype_str}"] = ppr_edge_index - metadata[f"{PPR_WEIGHT_METADATA_KEY}{etype_str}"] = homo_flat_weights + metadata["edge_index"] = ppr_edge_index + metadata["edge_attr"] = homo_flat_weights sample_output = SamplerOutput( node=all_nodes, diff --git a/gigl/distributed/distributed_neighborloader.py b/gigl/distributed/distributed_neighborloader.py index 2dcf65a55..0ef76ac9e 100644 --- a/gigl/distributed/distributed_neighborloader.py +++ b/gigl/distributed/distributed_neighborloader.py @@ -18,28 +18,20 @@ from gigl.distributed.base_dist_loader import BaseDistLoader from gigl.distributed.dist_context import DistributedContext from gigl.distributed.dist_dataset import DistDataset -from gigl.distributed.dist_ppr_sampler import ( - PPR_EDGE_INDEX_METADATA_KEY, - PPR_WEIGHT_METADATA_KEY, -) from gigl.distributed.dist_sampling_producer import DistSamplingProducer from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset from gigl.distributed.sampler_options import ( - PPRSamplerOptions, SamplerOptions, resolve_sampler_options, ) from gigl.distributed.utils.neighborloader import ( DatasetSchema, SamplingClusterSetup, - attach_ppr_outputs, - extract_edge_type_metadata, extract_metadata, labeled_to_homogeneous, set_missing_features, shard_nodes_by_process, strip_label_edges, - strip_non_ppr_edge_types, ) from gigl.src.common.types.graph_data import ( NodeType, # TODO (mkolodner-sc): Change to use torch_geometric.typing @@ -557,16 +549,7 @@ def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]: if self._is_homogeneous_with_labeled_edge_type: data = labeled_to_homogeneous(DEFAULT_HOMOGENEOUS_EDGE_TYPE, data) - if isinstance(self._sampler_options, PPRSamplerOptions): - matched, metadata = extract_edge_type_metadata( - metadata=metadata, - prefixes=[PPR_EDGE_INDEX_METADATA_KEY, PPR_WEIGHT_METADATA_KEY], - ) - ppr_edge_indices = matched[PPR_EDGE_INDEX_METADATA_KEY] - ppr_weights = matched[PPR_WEIGHT_METADATA_KEY] - attach_ppr_outputs(data, ppr_edge_indices, ppr_weights) - if isinstance(data, HeteroData): - data = strip_non_ppr_edge_types(data, set(ppr_edge_indices.keys())) + data, metadata = self._apply_ppr_outputs(data, metadata) # Attach any remaining metadata (e.g. custom user-defined keys) directly onto the # data object so downstream code can access them via attribute lookup.