Merge pull request #36 from tinkle1129/master
add rejection sampling for node2vec
This commit is contained in:
commit
07783f4d20
@ -80,9 +80,9 @@ if __name__ == "__main__":
|
||||
G = nx.read_edgelist('../data/flight/brazil-airports.edgelist', create_using=nx.DiGraph(), nodetype=None,
|
||||
data=[('weight', int)])
|
||||
|
||||
model = Node2Vec(G, 10, 80, workers=1,p=0.25,q=2 )
|
||||
model = Node2Vec(G, 10, 80, workers=1, p=0.25, q=2, use_rejection_sampling=0)
|
||||
model.train()
|
||||
embeddings = model.get_embeddings()
|
||||
|
||||
evaluate_embeddings(embeddings)
|
||||
plot_embeddings(embeddings)
|
||||
plot_embeddings(embeddings)
|
||||
|
@ -44,9 +44,8 @@ def plot_embeddings(embeddings,):
|
||||
if __name__ == "__main__":
|
||||
G=nx.read_edgelist('../data/wiki/Wiki_edgelist.txt',
|
||||
create_using = nx.DiGraph(), nodetype = None, data = [('weight', int)])
|
||||
|
||||
model=Node2Vec(G, walk_length = 10, num_walks = 80,
|
||||
p = 0.25, q = 4, workers = 1)
|
||||
model = Node2Vec(G, walk_length=10, num_walks=80,
|
||||
p=0.25, q=4, workers=1, use_rejection_sampling=0)
|
||||
model.train(window_size = 5, iter = 3)
|
||||
embeddings=model.get_embeddings()
|
||||
|
||||
|
@ -26,11 +26,12 @@ from ..walker import RandomWalker
|
||||
|
||||
class Node2Vec:
|
||||
|
||||
def __init__(self, graph, walk_length, num_walks, p=1.0, q=1.0, workers=1):
|
||||
def __init__(self, graph, walk_length, num_walks, p=1.0, q=1.0, workers=1, use_rejection_sampling=0):
|
||||
|
||||
self.graph = graph
|
||||
self._embeddings = {}
|
||||
self.walker = RandomWalker(graph, p=p, q=q, )
|
||||
self.walker = RandomWalker(
|
||||
graph, p=p, q=q, use_rejection_sampling=use_rejection_sampling)
|
||||
|
||||
print("Preprocess transition probs...")
|
||||
self.walker.preprocess_transition_probs()
|
||||
|
71
ge/walker.py
71
ge/walker.py
@ -12,15 +12,17 @@ from .utils import partition_num
|
||||
|
||||
|
||||
class RandomWalker:
|
||||
def __init__(self, G, p=1, q=1):
|
||||
def __init__(self, G, p=1, q=1, use_rejection_sampling=0):
|
||||
"""
|
||||
:param G:
|
||||
:param p: Return parameter,controls the likelihood of immediately revisiting a node in the walk.
|
||||
:param q: In-out parameter,allows the search to differentiate between “inward” and “outward” nodes
|
||||
:param use_rejection_sampling: Whether to use the rejection sampling strategy in node2vec.
|
||||
"""
|
||||
self.G = G
|
||||
self.p = p
|
||||
self.q = q
|
||||
self.use_rejection_sampling = use_rejection_sampling
|
||||
|
||||
def deepwalk_walk(self, walk_length, start_node):
|
||||
|
||||
@ -61,6 +63,59 @@ class RandomWalker:
|
||||
|
||||
return walk
|
||||
|
||||
def node2vec_walk2(self, walk_length, start_node):
|
||||
"""
|
||||
Reference:
|
||||
KnightKing: A Fast Distributed Graph Random Walk Engine
|
||||
http://madsys.cs.tsinghua.edu.cn/publications/SOSP19-yang.pdf
|
||||
"""
|
||||
|
||||
def rejection_sample(inv_p, inv_q, nbrs_num):
|
||||
upper_bound = max(1.0, max(inv_p, inv_q))
|
||||
lower_bound = min(1.0, min(inv_p, inv_q))
|
||||
shatter = 0
|
||||
second_upper_bound = max(1.0, inv_q)
|
||||
if (inv_p > second_upper_bound):
|
||||
shatter = second_upper_bound / nbrs_num
|
||||
upper_bound = second_upper_bound + shatter
|
||||
return upper_bound, lower_bound, shatter
|
||||
|
||||
G = self.G
|
||||
alias_nodes = self.alias_nodes
|
||||
inv_p = 1.0 / self.p
|
||||
inv_q = 1.0 / self.q
|
||||
walk = [start_node]
|
||||
while len(walk) < walk_length:
|
||||
cur = walk[-1]
|
||||
cur_nbrs = list(G.neighbors(cur))
|
||||
if len(cur_nbrs) > 0:
|
||||
if len(walk) == 1:
|
||||
walk.append(
|
||||
cur_nbrs[alias_sample(alias_nodes[cur][0], alias_nodes[cur][1])])
|
||||
else:
|
||||
upper_bound, lower_bound, shatter = rejection_sample(
|
||||
inv_p, inv_q, len(cur_nbrs))
|
||||
prev = walk[-2]
|
||||
prev_nbrs = set(G.neighbors(prev))
|
||||
while True:
|
||||
prob = random.random() * upper_bound
|
||||
if (prob + shatter >= upper_bound):
|
||||
next_node = prev
|
||||
break
|
||||
next_node = cur_nbrs[alias_sample(
|
||||
alias_nodes[cur][0], alias_nodes[cur][1])]
|
||||
if (prob < lower_bound):
|
||||
break
|
||||
if (prob < inv_p and next_node == prev):
|
||||
break
|
||||
_prob = 1.0 if next_node in prev_nbrs else inv_q
|
||||
if (prob < _prob):
|
||||
break
|
||||
walk.append(next_node)
|
||||
else:
|
||||
break
|
||||
return walk
|
||||
|
||||
def simulate_walks(self, num_walks, walk_length, workers=1, verbose=0):
|
||||
|
||||
G = self.G
|
||||
@ -83,6 +138,9 @@ class RandomWalker:
|
||||
if self.p == 1 and self.q == 1:
|
||||
walks.append(self.deepwalk_walk(
|
||||
walk_length=walk_length, start_node=v))
|
||||
elif self.use_rejection_sampling:
|
||||
walks.append(self.node2vec_walk2(
|
||||
walk_length=walk_length, start_node=v))
|
||||
else:
|
||||
walks.append(self.node2vec_walk(
|
||||
walk_length=walk_length, start_node=v))
|
||||
@ -119,7 +177,6 @@ class RandomWalker:
|
||||
Preprocessing of transition probabilities for guiding the random walks.
|
||||
"""
|
||||
G = self.G
|
||||
|
||||
alias_nodes = {}
|
||||
for node in G.nodes():
|
||||
unnormalized_probs = [G[node][nbr].get('weight', 1.0)
|
||||
@ -129,14 +186,14 @@ class RandomWalker:
|
||||
float(u_prob)/norm_const for u_prob in unnormalized_probs]
|
||||
alias_nodes[node] = create_alias_table(normalized_probs)
|
||||
|
||||
alias_edges = {}
|
||||
if not self.use_rejection_sampling:
|
||||
alias_edges = {}
|
||||
|
||||
for edge in G.edges():
|
||||
alias_edges[edge] = self.get_alias_edge(edge[0], edge[1])
|
||||
for edge in G.edges():
|
||||
alias_edges[edge] = self.get_alias_edge(edge[0], edge[1])
|
||||
self.alias_edges = alias_edges
|
||||
|
||||
self.alias_nodes = alias_nodes
|
||||
self.alias_edges = alias_edges
|
||||
|
||||
return
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user