Compare commits

3 Commits

| Author | SHA1 | Date |
|---|---|---|
| | 1fae0a956b | |
| | 4ea51c3370 | |
| | 4a0ac3a480 | |
examples/deepwalk_wiki_csrgraph.py (new file, 53 lines)

@@ -0,0 +1,53 @@

```python
import numpy as np

from ge.classify import read_node_label, Classifier
from ge import DeepWalk
from sklearn.linear_model import LogisticRegression

import matplotlib.pyplot as plt
import networkx as nx
from sklearn.manifold import TSNE


def evaluate_embeddings(embeddings):
    X, Y = read_node_label('../data/wiki/wiki_labels.txt')
    tr_frac = 0.8
    print("Training classifier using {:.2f}% nodes...".format(
        tr_frac * 100))
    clf = Classifier(embeddings=embeddings, clf=LogisticRegression())
    clf.split_train_evaluate(X, Y, tr_frac)


def plot_embeddings(embeddings,):
    X, Y = read_node_label('../data/wiki/wiki_labels.txt')

    emb_list = []
    for k in X:
        emb_list.append(embeddings[k])
    emb_list = np.array(emb_list)

    model = TSNE(n_components=2)
    node_pos = model.fit_transform(emb_list)

    color_idx = {}
    for i in range(len(X)):
        color_idx.setdefault(Y[i][0], [])
        color_idx[Y[i][0]].append(i)

    for c, idx in color_idx.items():
        plt.scatter(node_pos[idx, 0], node_pos[idx, 1], label=c)
    plt.legend()
    plt.show()


if __name__ == "__main__":
    G = nx.read_edgelist('../data/wiki/Wiki_edgelist.txt',
                         create_using=nx.DiGraph(), nodetype=None, data=[('weight', int)])

    model = DeepWalk(G, walk_length=10, num_walks=80, workers=1, use_csrgraph=True)
    model.train(window_size=5, iter=3)
    embeddings = model.get_embeddings()

    evaluate_embeddings(embeddings)
    # plot_embeddings(embeddings)
```
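For context on what `use_csrgraph=True` switches on: walk generation is delegated to the csrgraph library, which runs the walks over a CSR adjacency structure, instead of the repo's pure-Python `RandomWalker`. Below is a minimal standalone sketch of that walk generation, mirroring the constructor call and `random_walks` parameters from the diff that follows; the edge-list path comes from the example above, and using csrgraph outside the `DeepWalk` class like this is an illustrative assumption, not part of the PR.

```python
# Sketch: DeepWalk-style walks straight from csrgraph, outside the DeepWalk class.
import networkx as nx
from csrgraph import csrgraph

# Same edge list the example loads.
nxG = nx.read_edgelist('../data/wiki/Wiki_edgelist.txt',
                       create_using=nx.DiGraph(), nodetype=None,
                       data=[('weight', int)])

# Wrap the networkx graph in a CSR structure; nodenames lets us map
# integer node ids back to the original labels later.
names = list(nxG.nodes())
G = csrgraph(nxG, nodenames=names, threads=1)

# return_weight == neighbor_weight == 1.0 makes the node2vec-style walk
# unbiased, i.e. plain DeepWalk random walks.
walks = G.random_walks(walklen=10, epochs=80,
                       return_weight=1., neighbor_weight=1.)
print(walks.shape)  # one row per walk; entries are integer node ids
```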
```diff
@@ -20,15 +20,31 @@ Reference:
 from ..walker import RandomWalker
 from gensim.models import Word2Vec
 import pandas as pd
 
+import numpy as np
+from csrgraph import csrgraph
 
 class DeepWalk:
-    def __init__(self, graph, walk_length, num_walks, workers=1):
-
-        self.graph = graph
+    def __init__(self, graph, walk_length, num_walks, workers=1, use_csrgraph=False):
+        self.use_csrgraph = use_csrgraph
         self.w2v_model = None
         self._embeddings = {}
 
+        if self.use_csrgraph:
+            node_names = list(graph.nodes())
+            self.graph = csrgraph(graph, nodenames=node_names, threads=workers)
+            self.sentences = pd.DataFrame(self.graph.random_walks(
+                epochs=num_walks, walklen=walk_length, return_weight=1., neighbor_weight=1.))
+            # Map node id -> node name
+            node_dict = dict(zip(np.arange(len(node_names)), node_names))
+            for col in self.sentences.columns:
+                self.sentences[col] = self.sentences[col].map(node_dict).astype(str)
+            # Somehow gensim only trains on this list iterator;
+            # it silently mistrains on array input.
+            self.sentences = [list(x) for x in self.sentences.itertuples(False, None)]
+        else:
+            self.graph = graph
+            self.walker = RandomWalker(
+                graph, p=1, q=1, )
+            self.sentences = self.walker.simulate_walks(
```
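The remapping block above exists because `random_walks` returns integer node ids, while gensim expects sentences as lists of string tokens. Here is a toy illustration of the same pandas transformation; the walk array and node names are made up, but the calls match the diff.

```python
import numpy as np
import pandas as pd

# Toy stand-in for csrgraph output: 3 walks of length 4 over node ids 0..2.
walks = np.array([[0, 1, 2, 1],
                  [1, 0, 0, 2],
                  [2, 2, 1, 0]])
node_names = ['alice', 'bob', 'carol']   # hypothetical node labels
node_dict = dict(enumerate(node_names))  # id -> name, as in the diff

df = pd.DataFrame(walks)
for col in df.columns:
    df[col] = df[col].map(node_dict).astype(str)

# gensim trains reliably on a list of token lists, not a raw array,
# hence the explicit conversion (see the comment in the diff above).
sentences = [list(row) for row in df.itertuples(False, None)]
print(sentences[0])  # ['alice', 'bob', 'carol', 'bob']
```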
```diff
@@ -36,20 +52,38 @@ class DeepWalk:
 
     def train(self, embed_size=128, window_size=5, workers=3, iter=5, **kwargs):
 
-        kwargs["sentences"] = self.sentences
-        kwargs["min_count"] = kwargs.get("min_count", 0)
-        kwargs["size"] = embed_size
-        kwargs["sg"] = 1  # skip gram
-        kwargs["hs"] = 1  # deepwalk use Hierarchical Softmax
-        kwargs["workers"] = workers
-        kwargs["window"] = window_size
-        kwargs["iter"] = iter
-
-        print("Learning embedding vectors...")
-        model = Word2Vec(**kwargs)
-        print("Learning embedding vectors done!")
+        if self.use_csrgraph:
+            kwargs["sentences"] = self.sentences
+            kwargs["min_count"] = kwargs.get("min_count", 0)
+            kwargs["vector_size"] = embed_size
+            kwargs["sg"] = 1  # skip-gram
+            kwargs["hs"] = 1  # DeepWalk uses hierarchical softmax
+            kwargs["workers"] = workers
+            kwargs["window"] = window_size
+            kwargs["epochs"] = iter
+
+            print("Learning embedding vectors...")
+            model = Word2Vec(**kwargs)
+            print("Learning embedding vectors done!")
+            self.w2v_model = model
+        else:
+            kwargs["sentences"] = self.sentences
+            kwargs["min_count"] = kwargs.get("min_count", 0)
+            kwargs["vector_size"] = embed_size
+            kwargs["sg"] = 1  # skip-gram
+            kwargs["hs"] = 1  # DeepWalk uses hierarchical softmax
+            kwargs["workers"] = workers
+            kwargs["window"] = window_size
+            kwargs["epochs"] = iter
+
+            print("Learning embedding vectors...")
+            model = Word2Vec(**kwargs)
+            print("Learning embedding vectors done!")
 
         self.w2v_model = model
 
         return model
 
     def get_embeddings(self,):
```
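The keyword renames visible in this hunk (`size` to `vector_size`, `iter` to `epochs`) track gensim's 4.0 API change; the two branches otherwise configure Word2Vec identically. A minimal sketch of the resulting call under gensim >= 4.0, using the same hyperparameters `train` sets (the toy sentences are illustrative stand-ins for the walk sequences):

```python
from gensim.models import Word2Vec

# Toy corpus: in the PR these sentences are the random-walk node sequences.
sentences = [['a', 'b', 'c'], ['b', 'a', 'a'], ['c', 'c', 'b']]

model = Word2Vec(
    sentences=sentences,
    min_count=0,      # keep every node, however rare
    vector_size=128,  # named `size` before gensim 4.0
    sg=1,             # skip-gram
    hs=1,             # hierarchical softmax, as DeepWalk prescribes
    workers=3,
    window=5,
    epochs=3,         # named `iter` before gensim 4.0
)
print(model.wv['a'].shape)  # -> (128,)
```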