add vis API
This commit is contained in:
parent
11f610f2db
commit
05a6dd8565
BIN
log/viz.jpg
BIN
log/viz.jpg
Binary file not shown.
Before Width: | Height: | Size: 201 KiB |
72
src/vis.py
72
src/vis.py
@ -1,12 +1,19 @@
|
||||
import os
|
||||
|
||||
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from tensorflow.contrib.tensorboard.plugins import projector
|
||||
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def parse_args():
    """Parse command-line options for the embedding-visualization script."""
    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter,
                            conflict_handler='resolve')
    # (flag, default value, help text) for every supported option
    option_specs = [
        ('--label-file', 'data/cora/cora_label.txt',
         'node label file'),
        ('--emb-file', 'emb/unnamed_node_embs.txt',
         'node embeddings file; suggest: data_method_dim_embs.txt'),
    ]
    for flag, default, help_text in option_specs:
        parser.add_argument(flag, default=default, help=help_text)
    return parser.parse_args()
|
||||
|
||||
def read_node_label(filename):
|
||||
with open(filename, 'r') as f:
|
||||
node_label = {} # dict
|
||||
@ -25,44 +32,45 @@ def read_node_emb(filename):
|
||||
node_emb[int(vec[0])] = [float(i) for i in vec[1:]]
|
||||
return node_emb
|
||||
|
||||
def main(args):
    """Export node embeddings to TensorBoard projector format.

    Loads node labels and embeddings from the files named by
    ``args.label_file`` / ``args.emb_file``, checks that both cover the
    same node ids, then writes into ``log/``:
      - ``embeddings.tsv`` and ``node_labels.tsv`` (tab-separated dumps), and
      - a TF checkpoint plus projector config so that
        ``tensorboard --logdir=log`` can visualize the embeddings.

    Raises SystemExit with a non-zero code when the label and embedding
    files do not describe the same set of node ids.
    """
    # --------load the node label and saved embeddings
    label_file = args.label_file
    emb_file = args.emb_file
    label_dict = read_node_label(label_file)
    emb_dict = read_node_emb(emb_file)

    # ids must match 1:1, otherwise embedding rows and labels would misalign
    if label_dict.keys() != emb_dict.keys():
        print('ERROR, node ids are not matched! Plz check again')
        # exit with a failure status; the original exit(0) signalled success
        raise SystemExit(1)

    # sort both by node id so row i of the embeddings matches label i
    embeddings = np.array([emb_dict[i] for i in sorted(emb_dict.keys())],
                          dtype=np.float32)
    labels = [label_dict[i] for i in sorted(label_dict.keys())]

    # --------save embeddings and labels
    os.makedirs('log', exist_ok=True)  # to_csv/Saver need the dir to exist
    emb_df = pd.DataFrame(embeddings)
    emb_df.to_csv('log/embeddings.tsv', sep='\t', header=False, index=False)

    lab_df = pd.Series(labels, name='label')
    lab_df.to_frame().to_csv('log/node_labels.tsv', header=False, index=False)

    # --------save tf variable
    embeddings_var = tf.Variable(embeddings, name='embeddings')
    # context manager releases the session's resources (original leaked it)
    with tf.Session() as sess:
        saver = tf.train.Saver([embeddings_var])
        sess.run(embeddings_var.initializer)
        saver.save(sess, os.path.join('log', "model.ckpt"), 1)

    # --------configure tf projector
    config = projector.ProjectorConfig()
    embedding = config.embeddings.add()
    embedding.tensor_name = 'embeddings'
    # path is resolved relative to the log dir handed to TensorBoard
    embedding.metadata_path = 'node_labels.tsv'

    projector.visualize_embeddings(tf.summary.FileWriter('log'), config)
|
||||
|
||||
# type "tensorboard --logdir=emb/log" in CMD and have fun :)
|
||||
if __name__ == '__main__':
    # Parse CLI options once, run the export, then tell the user how to view it.
    cli_args = parse_args()
    main(cli_args)
    print('Run "tensorboard --logdir=log" in CMD and then, copy the given address to web browser')
|
||||
|
Loading…
Reference in New Issue
Block a user