add vis API

Chengbin Hou 2018-11-30 21:36:50 +00:00
parent 11f610f2db
commit 05a6dd8565
2 changed files with 40 additions and 32 deletions

Binary file not shown (before: 201 KiB).

@@ -1,12 +1,19 @@
import os
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
import numpy as np
import tensorflow as tf
from tensorflow.contrib.tensorboard.plugins import projector
import pandas as pd

def parse_args():
    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter, conflict_handler='resolve')
    parser.add_argument('--label-file', default='data/cora/cora_label.txt',
                        help='node label file')
    parser.add_argument('--emb-file', default='emb/unnamed_node_embs.txt',
                        help='node embeddings file; suggest: data_method_dim_embs.txt')
    return parser.parse_args()

def read_node_label(filename):
    with open(filename, 'r') as f:
        node_label = {}  # dict
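
The parse_args() helper added in this hunk replaces the previously hard-coded paths with a small CLI. A minimal sketch of the namespace it produces; the script filename (vis.py here) and the embedding filename are illustrative assumptions, only the flag names and the label-file default come from the parser above:

from argparse import Namespace

# Equivalent of:
#   python vis.py --label-file data/cora/cora_label.txt --emb-file emb/cora_abrw_128_embs.txt
# argparse maps --label-file to args.label_file and --emb-file to args.emb_file
args = Namespace(label_file='data/cora/cora_label.txt',
                 emb_file='emb/cora_abrw_128_embs.txt')
print(args.label_file, args.emb_file)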
@@ -25,44 +32,45 @@ def read_node_emb(filename):
            node_emb[int(vec[0])] = [float(i) for i in vec[1:]]
    return node_emb

def main(args):
    # --------load the node label and saved embeddings
    label_file = args.label_file
    emb_file = args.emb_file
    # load the node label and saved embeddings
    label_file = './data/cora/cora_label.txt'
    emb_file = './emb/abrw.txt'
    label_dict = read_node_label(label_file)
    emb_dict = read_node_emb(emb_file)
    label_dict = read_node_label(label_file)
    emb_dict = read_node_emb(emb_file)
    if label_dict.keys() != emb_dict.keys():
        print('ERROR, node ids are not matched! Plz check again')
        exit(0)
    if label_dict.keys() != emb_dict.keys():
        print('ERROR, node ids are not matched! Plz check again')
        exit(0)
    # embeddings = np.array([i for i in emb_dict.values()], dtype=np.float32)
    embeddings = np.array([emb_dict[i] for i in sorted(emb_dict.keys(), reverse=False)], dtype=np.float32)
    labels = [label_dict[i] for i in sorted(label_dict.keys(), reverse=False)]
    embeddings = np.array([emb_dict[i] for i in sorted(emb_dict.keys(), reverse=False)], dtype=np.float32)
    labels = [label_dict[i] for i in sorted(label_dict.keys(), reverse=False)]
    # save embeddings and labels
    emb_df = pd.DataFrame(embeddings)
    emb_df.to_csv('emb/log/embeddings.tsv', sep='\t', header=False, index=False)
    # --------save embeddings and labels
    emb_df = pd.DataFrame(embeddings)
    emb_df.to_csv('log/embeddings.tsv', sep='\t', header=False, index=False)
    lab_df = pd.Series(labels, name='label')
    lab_df.to_frame().to_csv('emb/log/node_labels.tsv', header=False, index=False)
    lab_df = pd.Series(labels, name='label')
    lab_df.to_frame().to_csv('log/node_labels.tsv', header=False, index=False)
    # save tf variable
    embeddings_var = tf.Variable(embeddings, name='embeddings')
    sess = tf.Session()
    # --------save tf variable
    embeddings_var = tf.Variable(embeddings, name='embeddings')
    sess = tf.Session()
    saver = tf.train.Saver([embeddings_var])
    sess.run(embeddings_var.initializer)
    saver.save(sess, os.path.join('emb/log', "model.ckpt"), 1)
    saver = tf.train.Saver([embeddings_var])
    sess.run(embeddings_var.initializer)
    saver.save(sess, os.path.join('log', "model.ckpt"), 1)
    # configure tf projector
    config = projector.ProjectorConfig()
    embedding = config.embeddings.add()
    embedding.tensor_name = 'embeddings'
    embedding.metadata_path = 'node_labels.tsv'
    # --------configure tf projector
    config = projector.ProjectorConfig()
    embedding = config.embeddings.add()
    embedding.tensor_name = 'embeddings'
    embedding.metadata_path = 'node_labels.tsv'
    projector.visualize_embeddings(tf.summary.FileWriter('emb/log'), config)
    projector.visualize_embeddings(tf.summary.FileWriter('log'), config)
    # type "tensorboard --logdir=emb/log" in CMD and have fun :)

if __name__ == '__main__':
    main(parse_args())
    print('Run "tensorboard --logdir=log" in CMD and then, copy the given address to web browser')
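
After main() has written the embeddings TSV, the label TSV, the checkpoint, and the projector config into the log directory, TensorBoard serves them from that directory. A rough sanity-check sketch, assuming the log directory is ./log as in the print statement above (the diff also shows an emb/log variant, so adjust the path to match the version you run):

import pandas as pd

# Hedged sketch: verify the artifacts written by main() before launching TensorBoard.
emb = pd.read_csv('log/embeddings.tsv', sep='\t', header=None)  # one row per node, one column per dimension
lab = pd.read_csv('log/node_labels.tsv', header=None)           # one label per node, same row order as the embeddings
assert len(emb) == len(lab)
print(emb.shape)
# then run:  tensorboard --logdir=log  and open the Projector tab in the browser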