forked from haonan-yuan/RAG-GFM
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsubgraph_dataset.py
More file actions
56 lines (41 loc) · 1.89 KB
/
subgraph_dataset.py
File metadata and controls
56 lines (41 loc) · 1.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import torch
from torch_geometric.data import Dataset, Data
from torch_geometric.utils import k_hop_subgraph
from typing import Optional, List, Union
class SubgraphViewDataset(Dataset):
    """Dataset of paired k-hop subgraph views centred on selected nodes.

    Each item is built on the fly: the ``num_hops``-hop neighbourhood of one
    selected node is extracted from ``full_graph_data`` and returned twice,
    once with rows of ``cse_features`` as node features (the "struct" view)
    and once with rows of ``full_graph_data.x`` (the "semantic" view). Both
    views share the same relabelled ``edge_index``.

    Args:
        full_graph_data: Graph to sample from; its ``edge_index``,
            ``num_nodes`` and ``x`` are read.
        cse_features: Node features for the struct view. Indexed by
            full-graph node id along dim 0, so it presumably has one row per
            node of ``full_graph_data`` -- this is not validated here.
        top_k_node_indices: Full-graph ids of the centre nodes, one per
            dataset item.
        num_hops: Neighbourhood radius forwarded to ``k_hop_subgraph``.
            Defaults to 1, the value that used to be hard-coded.
    """

    def __init__(self,
                 full_graph_data: Data,
                 cse_features: torch.Tensor,
                 top_k_node_indices: Union[torch.Tensor, List[int]],
                 num_hops: int = 1):
        super().__init__()
        self.full_graph_data = full_graph_data
        self.cse_features = cse_features
        # as_tensor normalises lists, tuples, arrays and tensors to a long
        # tensor; an input that is already a long tensor is returned as-is
        # (no copy), so the common case behaves exactly as before.
        self.top_k_node_indices = torch.as_tensor(
            top_k_node_indices, dtype=torch.long
        )
        self.num_hops = num_hops

    def len(self) -> int:
        """Return the number of centre nodes, i.e. the number of items."""
        return len(self.top_k_node_indices)

    def get(self, idx: int) -> tuple[Data, Data]:
        """Build the struct and semantic views for the ``idx``-th centre node.

        Returns:
            ``(struct_view, semantic_view)``. Both carry the relabelled
            ``edge_index``, the centre node's full-graph id as
            ``center_node_idx``, the full-graph ids of all subgraph nodes as
            ``subgraph_nodes``, and ``num_nodes``; they differ only in ``x``.
        """
        center_node_idx = self.top_k_node_indices[idx].item()

        # k_hop_subgraph also returns the centre node's position inside the
        # relabelled subgraph and an edge mask; neither is used by the views.
        subgraph_nodes, subgraph_edge_index, _, _ = k_hop_subgraph(
            node_idx=center_node_idx,
            num_hops=self.num_hops,
            edge_index=self.full_graph_data.edge_index,
            relabel_nodes=True,
            num_nodes=self.full_graph_data.num_nodes,
        )

        # NOTE(review): ``center_node_idx`` is an id in the FULL graph, while
        # ``edge_index`` is relabelled to 0..len(subgraph_nodes)-1. Consumers
        # needing the centre's row in ``x`` must locate it via
        # ``subgraph_nodes`` -- confirm downstream code does so.
        shared_fields = dict(
            edge_index=subgraph_edge_index,
            center_node_idx=center_node_idx,
            subgraph_nodes=subgraph_nodes,
            num_nodes=len(subgraph_nodes),
        )

        struct_view = Data(x=self.cse_features[subgraph_nodes], **shared_fields)
        semantic_view = Data(
            x=self.full_graph_data.x[subgraph_nodes], **shared_fields
        )
        return struct_view, semantic_view