parallel walker

Dongzy 2018-11-17 21:39:30 +08:00
parent b046b20090
commit cbf674e3cd


@@ -1,10 +1,10 @@
 # -*- coding: utf-8 -*-
 from __future__ import print_function
+import functools
 import multiprocessing
 import random
 import time
-from itertools import chain
 import numpy as np
 from networkx import nx
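Two import changes support the new parallel code path: the unused "from itertools import chain" is dropped, and functools comes in for functools.partial. Pool.map hands exactly one argument to the mapped function, so walk_length has to be bound in advance (see simulate_walks below). A minimal sketch of that binding, with a stub standing in for the real node2vec_walk:

import functools

def walk(start_node, walk_length):   # stub standing in for node2vec_walk
    return [start_node] * walk_length

walk_fn = functools.partial(walk, walk_length=3)   # one-argument callable, fit for pool.map
assert walk_fn(7) == walk(7, walk_length=3) == [7, 7, 7]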
@@ -40,105 +40,70 @@ class BiasedWalker:  # ------ our method
     # alias sampling for ABRW-------------------------------------------------------------------
     def simulate_walks(self, num_walks, walk_length):
         self.P_G = nx.to_networkx_graph(self.P, create_using=nx.DiGraph())  # create a new nx graph based on ABRW transition prob matrix
+        global P_G
+        P_G = self.P_G
         t1 = time.time()
         self.preprocess_transition_probs()  # note: we simply adapt node2vec
         t2 = time.time()
+        global alias_nodes
+        alias_nodes = self.alias_nodes
         print('Time for construct alias table: {:.2f}'.format(t2-t1))
         walks = []
         nodes = list(self.P_G.nodes())
         print('Walk iteration:')
+        pool = multiprocessing.Pool(self.workers)
         for walk_iter in range(num_walks):
             print(str(walk_iter+1), '/', str(num_walks))
-            random.shuffle(nodes)
-            for node in nodes:
-                walks.append(self.node2vec_walk(walk_length=walk_length, start_node=node))
+            # random.shuffle(nodes)
+            walks += pool.map(functools.partial(node2vec_walk, walk_length=walk_length), nodes)
+        pool.close()
+        pool.join()
+        del alias_nodes, P_G
         for i in range(len(walks)):  # use index to retrieve the original node ID
             for j in range(len(walks[0])):
                 walks[i][j] = self.look_back_list[int(walks[i][j])]
         return walks
 
-    def node2vec_walk(self, walk_length, start_node):  # to do...
-        G = self.P_G  # more efficient way instead of copy from node2vec
-        alias_nodes = self.alias_nodes
-        walk = [start_node]
-        while len(walk) < walk_length:
-            cur = walk[-1]
-            cur_nbrs = list(G.neighbors(cur))
-            if len(cur_nbrs) > 0:
-                walk.append(cur_nbrs[alias_draw(alias_nodes[cur][0], alias_nodes[cur][1])])
-            else:
-                break
-        return walk
     def preprocess_transition_probs(self):
         G = self.P_G
         alias_nodes = {}
-        for node in G.nodes():
-            unnormalized_probs = [G[node][nbr]['weight'] for nbr in G.neighbors(node)]
-            norm_const = sum(unnormalized_probs)
-            normalized_probs = [float(u_prob)/norm_const for u_prob in unnormalized_probs]
-            alias_nodes[node] = alias_setup(normalized_probs)
+        nodes = G.nodes()
+        pool = multiprocessing.Pool(self.workers)
+        alias_nodes = dict(zip(nodes, pool.map(get_alias_node, nodes)))
+        pool.close()
+        pool.join()
         self.alias_nodes = alias_nodes
 
-    '''
-    #naive sampling for ABRW-------------------------------------------------------------------
-    def weighted_walk(self, start_node):
-        #
-        #Simulate a weighted walk starting from start node.
-        #
-        G = self.G
-        look_up_dict = self.look_up_dict
-        look_back_list = self.look_back_list
-        node_size = self.node_size
-        walk = [start_node]
-
-        while len(walk) < self.walk_length:
-            cur_node = walk[-1]  #the last entry/node
-            cur_ind = look_up_dict[cur_node]  #key -> index
-            pdf = self.P[cur_ind,:]  #the pdf of node with ind
-            #pdf = np.random.randn(18163)+10  #......test multiprocessor
-            #pdf = pdf / pdf.sum()  #......test multiprocessor
-            #next_ind = int( np.array( nx.utils.random_sequence.discrete_sequence(n=1,distribution=pdf) ) )
-            next_ind = np.random.choice(len(pdf), 1, p=pdf)[0]  #faster than nx
-            #next_ind = 0  #......test multiprocessor
-            next_node = look_back_list[next_ind]  #index -> key
-            walk.append(next_node)
-        return walk
-
-    def simulate_walks(self, num_walks, walk_length):
-        #
-        #Repeatedly simulate weighted walks from each node.
-        #
-        G = self.G
-        self.num_walks = num_walks
-        self.walk_length = walk_length
-        self.walks = []  #what we all need later as input to skip-gram
-        nodes = list(G.nodes())
-        print('Walk iteration:')
-        for walk_iter in range(num_walks):
-            t1 = time.time()
-            random.shuffle(nodes)
-            for node in nodes:  #for single cpu, if # of nodes < 2000 (speed up) or > 20000 (avoid memory error)
-                self.walks.append(self.weighted_walk(node))
-            #pool = multiprocessing.Pool(processes=3)  #use all cpu by default or specify processes=xx
-            #self.walks.append(pool.map(self.weighted_walk, nodes))  #ref: https://stackoverflow.com/questions/8533318/multiprocessing-pool-when-to-use-apply-apply-async-or-map
-            #pool.close()
-            #pool.join()
-            t2 = time.time()
-            print(str(walk_iter+1), '/', str(num_walks), ' each itr last for: {:.2f}s'.format(t2-t1))
-        #self.walks = list(chain.from_iterable(self.walks))  #unlist... [[[x,x],[x,x]]] -> [x,x], [x,x]
-        return self.walks
-    '''
+
+def node2vec_walk(start_node, walk_length):  # to do...
+    global P_G  # more efficient way instead of copy from node2vec
+    global alias_nodes
+    walk = [start_node]
+    while len(walk) < walk_length:
+        cur = walk[-1]
+        cur_nbrs = list(P_G.neighbors(cur))
+        if len(cur_nbrs) > 0:
+            walk.append(cur_nbrs[alias_draw(alias_nodes[cur][0], alias_nodes[cur][1])])
+        else:
+            break
+    return walk
+
+def get_alias_node(node):
+    global P_G
+    unnormalized_probs = [P_G[node][nbr]['weight'] for nbr in P_G.neighbors(node)]
+    norm_const = sum(unnormalized_probs)
+    normalized_probs = [float(u_prob)/norm_const for u_prob in unnormalized_probs]
+    return alias_setup(normalized_probs)
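The substance of the change: node2vec_walk and the new get_alias_node move from methods to module-level functions, and the transition graph (P_G) plus the alias tables (alias_nodes) are published as module globals. Pool workers can only be dispatched a picklable callable, and under Python 2 a bound method such as self.node2vec_walk is not picklable; with the default fork start method on Linux, forked workers also inherit the parent's globals, so the read-only graph and tables are shared without being serialized once per task. The "del alias_nodes, P_G" then drops the parent's module-level references once the walks are collected. A self-contained sketch of the same pattern (GRAPH and sample_successors are hypothetical stand-ins, not names from this repository); it relies on fork, so it will not work under the spawn start method (e.g. on Windows):

import functools
import multiprocessing

GRAPH = None   # module-level read-only state; forked workers inherit it

def sample_successors(node, walk_length):
    # toy stand-in for node2vec_walk: follow each node's single successor
    walk = [node]
    while len(walk) < walk_length:
        walk.append(GRAPH[walk[-1]])
    return walk

if __name__ == '__main__':
    GRAPH = {0: 1, 1: 2, 2: 0}   # toy graph: node -> successor
    ctx = multiprocessing.get_context('fork')   # POSIX-only; mirrors the commit's assumption
    pool = ctx.Pool(2)
    walks = pool.map(functools.partial(sample_successors, walk_length=4), [0, 1, 2])
    pool.close()
    pool.join()
    print(walks)   # [[0, 1, 2, 0], [1, 2, 0, 1], [2, 0, 1, 2]]

One trade-off visible in the diff itself: random.shuffle(nodes) is commented out, so each iteration now starts walks from the nodes in a fixed order, and pool.map returns the walks in that same input order.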
 # ===========================================deepWalk-walker============================================
 class BasicWalker:
     def __init__(self, G, workers):
         self.G = G.G
-        self.node_size = G.get_num_nodes()
         self.look_up_dict = G.look_up_dict
 
     def deepwalk_walk(self, walk_length, start_node):
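alias_setup and alias_draw come from elsewhere in the repository and are not part of this diff. In node2vec-derived code they implement the alias method for sampling from a discrete distribution: O(K) setup per node, then O(1) per draw, which is what makes precomputing alias_nodes worthwhile. A sketch along the lines of the standard node2vec reference implementation (an assumption, not verified against this repository); alias_nodes[cur][0] is the alias table J and alias_nodes[cur][1] the threshold table q:

import numpy as np

def alias_setup(probs):
    # O(K) construction of alias tables for a normalized distribution
    K = len(probs)
    q = np.zeros(K)                  # scaled/threshold probabilities
    J = np.zeros(K, dtype=np.int64)  # alias indices
    smaller, larger = [], []
    for kk, prob in enumerate(probs):
        q[kk] = K * prob
        (smaller if q[kk] < 1.0 else larger).append(kk)
    while smaller and larger:
        small, large = smaller.pop(), larger.pop()
        J[small] = large
        q[large] = q[large] - (1.0 - q[small])
        (smaller if q[large] < 1.0 else larger).append(large)
    return J, q

def alias_draw(J, q):
    # O(1) draw: pick a column uniformly, then keep it or take its alias
    kk = int(np.floor(np.random.rand() * len(J)))
    return kk if np.random.rand() < q[kk] else J[kk]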