Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 8 additions & 13 deletions gigl/distributed/dist_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,16 +499,11 @@ def _initialize_node_ids(
)
else:
train_nodes, val_nodes, test_nodes = splits
self._num_train = (
train_nodes.numel() # ty: ignore[unresolved-attribute]
)
self._num_val = val_nodes.numel() # ty: ignore[unresolved-attribute]
self._num_test = test_nodes.numel() # ty: ignore[unresolved-attribute]
self._num_train = train_nodes.numel()
self._num_val = val_nodes.numel()
self._num_test = test_nodes.numel()
self._node_ids = _append_non_split_node_ids(
train_nodes, # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access.
val_nodes, # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access.
test_nodes, # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access.
node_ids_on_machine,
train_nodes, val_nodes, test_nodes, node_ids_on_machine
)
else:
logger.info(
Expand Down Expand Up @@ -642,8 +637,8 @@ def _initialize_node_features(
# if it is not an edge type, since it must be one of the two.
assert not isinstance(node_type, EdgeType)
self._node_feature_info[node_type] = FeatureInfo(
dim=node_features_per_node_type.size(1), # ty: ignore[unresolved-attribute] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access.
dtype=node_features_per_node_type.dtype, # ty: ignore[unresolved-attribute] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access.
dim=node_features_per_node_type.size(1),
dtype=node_features_per_node_type.dtype,
)
logger.info(
f"Initialized node features for heterogeneous graph to dataset with node types: {node_features.keys()}"
Expand Down Expand Up @@ -725,8 +720,8 @@ def _initialize_edge_features(
for edge_type, edge_features_per_edge_type in edge_features.items():
assert isinstance(edge_type, EdgeType)
self._edge_feature_info[edge_type] = FeatureInfo(
dim=edge_features_per_edge_type.size(1), # ty: ignore[unresolved-attribute] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access.
dtype=edge_features_per_edge_type.dtype, # ty: ignore[unresolved-attribute] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access.
dim=edge_features_per_edge_type.size(1),
dtype=edge_features_per_edge_type.dtype,
)
logger.info(
f"Initialized edge features for heterogeneous graph to dataset with edge types: {edge_features.keys()}"
Expand Down
10 changes: 5 additions & 5 deletions gigl/distributed/dist_ppr_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ async def _sample_from_nodes(
seed_types = list(nodes_to_sample.keys())
ppr_results = await asyncio.gather(
*[
self._compute_ppr_scores(nodes_to_sample[seed_type], seed_type) # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access.
self._compute_ppr_scores(nodes_to_sample[seed_type], seed_type)
for seed_type in seed_types
]
)
Expand All @@ -556,20 +556,20 @@ async def _sample_from_nodes(

for ntype, flat_ids in ntype_to_flat_ids.items():
ppr_edge_type: EdgeType = (seed_type, "ppr", ntype)
valid_counts = ntype_to_valid_counts[ntype] # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access.
valid_counts = ntype_to_valid_counts[ntype]
ppr_edge_type_to_flat_weights[ppr_edge_type] = (
ntype_to_flat_weights[ntype] # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access.
ntype_to_flat_weights[ntype]
)

# Skip empty pairs; induce_next handles deduplication across
# seed types so a neighbor reachable from multiple seed types
# gets one consistent local index in node_dict[ntype].
if flat_ids.numel() > 0: # ty: ignore[unresolved-attribute] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access.
if flat_ids.numel() > 0:
nbr_dict[ppr_edge_type] = [
src_dict[seed_type],
flat_ids,
valid_counts,
] # ty: ignore[invalid-assignment] TODO(ty-torch-container-shapes): fix ty false positives for torch container and return shapes.
]

# induce_next processes all PPR edge types in nbr_dict in one
# pass, assigning local indices to neighbors not yet registered and
Expand Down
Loading