-
Notifications
You must be signed in to change notification settings - Fork 1k
Expand file tree
/
Copy pathnode2vec.py
More file actions
73 lines (44 loc) · 2 KB
/
node2vec.py
File metadata and controls
73 lines (44 loc) · 2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
# -*- coding:utf-8 -*-
"""
Author:
Weichen Shen,wcshen1994@163.com
Reference:
[1] Grover A, Leskovec J. node2vec: Scalable feature learning for networks[C]//Proceedings of the 22nd ACM SIGKDD international conference on Knowledge discovery and data mining. ACM, 2016: 855-864.(https://www.kdd.org/kdd2016/papers/files/rfp0218-groverA.pdf)
"""
from gensim.models import Word2Vec, word2vec
import pandas as pd
from ..walker import RandomWalker
class Node2Vec:
    """node2vec embedding learner: biased random walks + gensim Word2Vec.

    Construction generates the random walks through ``RandomWalker``;
    ``train`` fits a ``Word2Vec`` model on a walk file; ``get_embeddings``
    returns one vector per node of the graph.
    """

    def __init__(self, graph, outlier, walk_length, num_walks, p=1.0, q=1.0,
                 workers=1, use_rejection_sampling=0):
        """Build the walker, preprocess transition probabilities and walk.

        Args:
            graph: Graph object; it is handed to ``RandomWalker`` and its
                ``nodes()`` method is used later by ``get_embeddings``.
            outlier: Forwarded untouched to ``RandomWalker.simulate_walks``.
                NOTE(review): its meaning is defined by the walker, not
                visible here -- confirm against ``walker.py``.
            walk_length: Length of each walk, forwarded to the walker.
            num_walks: Number of walks, forwarded to the walker.
            p: node2vec ``p`` parameter, forwarded to ``RandomWalker``.
            q: node2vec ``q`` parameter, forwarded to ``RandomWalker``.
            workers: Worker count used while simulating walks.
            use_rejection_sampling: Forwarded to ``RandomWalker``.
        """
        self.graph = graph
        self._embeddings = {}
        # Defined up front so that get_embeddings() can detect the
        # "not trained yet" state instead of raising AttributeError.
        self.w2v_model = None

        self.walker = RandomWalker(
            graph, p=p, q=q, use_rejection_sampling=use_rejection_sampling)

        print("Preprocess transition probs...")
        self.walker.preprocess_transition_probs()

        self.sentences = self.walker.simulate_walks(
            "node", outlier,
            num_walks=num_walks, walk_length=walk_length,
            workers=workers, verbose=1, weight=False)

    def train(self, walkfile, embed_size=128, window_size=5, workers=3,
              iter=5, sg=1, hs=1, **kwargs):
        """Fit a Word2Vec model on the walks stored in ``walkfile``.

        The corpus is read from ``walkfile`` through
        ``word2vec.Text8Corpus``; the in-memory ``self.sentences`` is not
        used here.

        Args:
            walkfile: Path of the walk corpus file given to ``Text8Corpus``.
            embed_size: Embedding dimension (Word2Vec ``vector_size``).
            window_size: Context window (Word2Vec ``window``).
            workers: Number of training worker threads.
            iter: Number of training epochs (Word2Vec ``epochs``). The name
                shadows the builtin but is kept for caller compatibility.
            sg: Word2Vec ``sg`` flag (1 = skip-gram).
            hs: Word2Vec ``hs`` flag (1 = hierarchical softmax).
            **kwargs: Extra ``Word2Vec`` keyword arguments. ``min_count``
                defaults to 1 when not supplied; the keys set explicitly
                below override any same-named entry in ``kwargs``.

        Returns:
            The trained ``Word2Vec`` model (also stored on
            ``self.w2v_model``).
        """
        kwargs["sentences"] = word2vec.Text8Corpus(walkfile)
        kwargs["min_count"] = kwargs.get("min_count", 1)
        kwargs["vector_size"] = embed_size
        kwargs["sg"] = sg
        # NOTE: the default hs=1 turns hierarchical softmax ON; pass hs=0
        # to train with negative sampling only.
        kwargs["hs"] = hs
        kwargs["workers"] = workers
        kwargs["window"] = window_size
        kwargs["epochs"] = iter

        print("Learning embedding vectors...")
        model = Word2Vec(**kwargs)
        print("Learning embedding vectors done!")

        self.w2v_model = model
        return model

    def get_embeddings(self,):
        """Return a ``{node: vector}`` dict for every node in the graph.

        Returns:
            An empty dict (after printing a notice) when no model has been
            trained yet; otherwise the mapping, which is also cached on
            ``self._embeddings``.
        """
        # getattr keeps this safe even for instances whose __init__ did
        # not run (e.g. objects restored without the attribute).
        model = getattr(self, "w2v_model", None)
        if model is None:
            print("model not train")
            return {}

        vectors = model.wv
        # NOTE(review): nodes are looked up with the graph's own key type;
        # presumably this matches the tokens in the walk file -- confirm.
        self._embeddings = {node: vectors[node] for node in self.graph.nodes()}
        return self._embeddings

    def get_sentences(self):
        """Return the walks produced by the walker at construction time."""
        return self.sentences