graphsageAPI as a class & add save emb method
This commit is contained in:
parent
ce0ce07d4c
commit
fa2fd1f21f
1
.gitignore
vendored
1
.gitignore
vendored
@ -18,6 +18,7 @@ db.ini
|
||||
deploy_key_rsa
|
||||
log
|
||||
bash
|
||||
emb
|
||||
|
||||
|
||||
#zeyu--------------------------------
|
||||
|
@ -1,34 +1,33 @@
|
||||
from __future__ import print_function
|
||||
from __future__ import division
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
#default parameters
|
||||
''' global parameters for graphsage models
|
||||
tune these parameters here if needed
|
||||
if needed use: from libnrl.graphsage.__init__ import *
|
||||
'''
|
||||
|
||||
#seed = 2018
|
||||
#np.random.seed(seed)
|
||||
#tf.set_random_seed(seed)
|
||||
log_device_placement = False
|
||||
|
||||
|
||||
# follow the orignal code by the paper author https://github.com/williamleif/GraphSAGE
|
||||
# we follow the opt parameters given by papers GCN and graphSAGE
|
||||
# note: citeseer+pubmed all follow the same parameters as cora, see their papers)
|
||||
# tensorflow + Adam optimizer + Random weight init + row norm of attr
|
||||
|
||||
epochs = 100
|
||||
dim_1 = 64 #dim = dim1+dim2 = 128
|
||||
dim_1 = 64 #dim = dim1+dim2 = 128 for sage-mean and sage-gcn
|
||||
dim_2 = 64
|
||||
samples_1 = 25
|
||||
samples_2 = 10
|
||||
learning_rate = 0.001
|
||||
dropout = 0.5
|
||||
weight_decay = 0.0001
|
||||
learning_rate = 0.0001
|
||||
batch_size = 128 #if run out of memory, try to reduce them, but we use the default e.g. 64, default=512
|
||||
normalize = True #row norm of node attributes/features
|
||||
samples_1 = 25
|
||||
samples_2 = 10
|
||||
|
||||
|
||||
#other parameters that paper did not mentioned, but we also follow the defaults https://github.com/williamleif/GraphSAGE
|
||||
model_size = 'small'
|
||||
max_degree = 100
|
||||
neg_sample_size = 20
|
||||
|
||||
random_context= True
|
||||
validate_batch_size = 64 #if run out of memory, try to reduce them, but we use the default e.g. 64, default=256
|
||||
validate_iter = 5000
|
||||
|
@ -1,112 +1,120 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
'''
|
||||
# author: Chengbin Hou @ SUSTech 2018 \n
|
||||
# to tune parameters, refer to graphsage->__init__.py \n
|
||||
|
||||
# we provide utils to transform the orignal data into graphSAGE format \n
|
||||
# the APIs are designed for unsupervised, \n
|
||||
# for supervised way, plz refer and complete 'to do...' \n
|
||||
# currently only support 'mean' and 'gcn' model \n
|
||||
'''
|
||||
#-----------------------------------------------------------------------------
|
||||
# author: Chengbin Hou @ SUSTech 2018
|
||||
# Email: Chengbin.Hou10@foxmail.com
|
||||
# we provide utils to transform the orignal data into graphSAGE format
|
||||
# you may easily use these APIs as what we demostrated in main.py of OpenANE
|
||||
# the APIs are designed for unsupervised, for supervised way, plz complete 'label' to do codes...
|
||||
#-----------------------------------------------------------------------------
|
||||
'''
|
||||
|
||||
from networkx.readwrite import json_graph
|
||||
import json
|
||||
import random
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
from libnrl.graphsage import unsupervised_train
|
||||
from libnrl.graphsage.__init__ import * #import default parameters
|
||||
|
||||
def add_train_val_test_to_G(graph, test_perc=0.0, val_perc=0.1): #due to unsupervised, we do not need test data
|
||||
G = graph.G #take out nx G
|
||||
random.seed(2018)
|
||||
num_nodes = nx.number_of_nodes(G)
|
||||
test_ind = random.sample(range(0, num_nodes), int(num_nodes*test_perc))
|
||||
val_ind = random.sample(range(0, num_nodes), int(num_nodes*val_perc))
|
||||
for ind in range(0, num_nodes):
|
||||
id = graph.look_back_list[ind]
|
||||
if ind in test_ind:
|
||||
G.nodes[id]['test'] = True
|
||||
G.nodes[id]['val'] = False
|
||||
elif ind in val_ind:
|
||||
G.nodes[id]['test'] = False
|
||||
G.nodes[id]['val'] = True
|
||||
class graphSAGE(object):
|
||||
def __init__(self, graph, sage_model='mean', is_supervised=False):
|
||||
self.graph = graph
|
||||
self.normalize = True #row normalization of node attributes
|
||||
self.num_walks = 50
|
||||
self.walk_len = 5
|
||||
|
||||
self.add_train_val_test_to_G(test_perc=0.0, val_perc=0.1) #if unsupervised, no test data
|
||||
train_data = self.tranform_data_for_graphsage() #obtain graphSAGE required training data
|
||||
|
||||
self.vectors = None
|
||||
if not is_supervised:
|
||||
from libnrl.graphsage import unsupervised_train
|
||||
self.vectors = unsupervised_train.train(train_data=train_data, test_data=None, model=sage_model)
|
||||
else:
|
||||
G.nodes[id]['test'] = False
|
||||
G.nodes[id]['val'] = False
|
||||
|
||||
## Make sure the graph has edge train_removed annotations
|
||||
## (some datasets might already have this..)
|
||||
print("Loaded data.. now preprocessing..")
|
||||
for edge in G.edges():
|
||||
if (G.node[edge[0]]['val'] or G.node[edge[1]]['val'] or
|
||||
G.node[edge[0]]['test'] or G.node[edge[1]]['test']):
|
||||
G[edge[0]][edge[1]]['train_removed'] = True
|
||||
else:
|
||||
G[edge[0]][edge[1]]['train_removed'] = False
|
||||
return G
|
||||
#to do...
|
||||
#from libnrl.graphsage import supervised_train
|
||||
#self.vectors = supervised_train.train()
|
||||
pass
|
||||
|
||||
def run_random_walks(G, num_walks=50, walk_len=5):
|
||||
nodes = [n for n in G.nodes() if not G.node[n]["val"] and not G.node[n]["test"]]
|
||||
G = G.subgraph(nodes)
|
||||
pairs = []
|
||||
for count, node in enumerate(nodes):
|
||||
if G.degree(node) == 0:
|
||||
continue
|
||||
for i in range(num_walks):
|
||||
curr_node = node
|
||||
for j in range(walk_len):
|
||||
if len(list(G.neighbors(curr_node))) == 0: #isolated nodes! often appeared in real-world
|
||||
break
|
||||
next_node = random.choice(list(G.neighbors(curr_node))) #changed due to compatibility
|
||||
#next_node = random.choice(G.neighbors(curr_node))
|
||||
# self co-occurrences are useless
|
||||
if curr_node != node:
|
||||
pairs.append((node,curr_node))
|
||||
curr_node = next_node
|
||||
if count % 1000 == 0:
|
||||
print("Done walks for", count, "nodes")
|
||||
return pairs
|
||||
|
||||
def tranform_data_for_graphsage(graph):
|
||||
G = add_train_val_test_to_G(graph) #given OpenANE graph --> obtain graphSAGE graph
|
||||
#G_json = json_graph.node_link_data(G) #train_data[0] in unsupervised_train.py
|
||||
def add_train_val_test_to_G(self, test_perc=0.0, val_perc=0.1):
|
||||
''' add if 'val' and/or 'test' to each node in G '''
|
||||
G = self.graph.G
|
||||
num_nodes = nx.number_of_nodes(G)
|
||||
test_ind = random.sample(range(0, num_nodes), int(num_nodes*test_perc))
|
||||
val_ind = random.sample(range(0, num_nodes), int(num_nodes*val_perc))
|
||||
for ind in range(0, num_nodes):
|
||||
id = self.graph.look_back_list[ind]
|
||||
if ind in test_ind:
|
||||
G.nodes[id]['test'] = True
|
||||
G.nodes[id]['val'] = False
|
||||
elif ind in val_ind:
|
||||
G.nodes[id]['test'] = False
|
||||
G.nodes[id]['val'] = True
|
||||
else:
|
||||
G.nodes[id]['test'] = False
|
||||
G.nodes[id]['val'] = False
|
||||
## Make sure the graph has edge train_removed annotations
|
||||
## (some datasets might already have this..)
|
||||
print("Loaded data.. now preprocessing..")
|
||||
for edge in G.edges():
|
||||
if (G.node[edge[0]]['val'] or G.node[edge[1]]['val'] or
|
||||
G.node[edge[0]]['test'] or G.node[edge[1]]['test']):
|
||||
G[edge[0]][edge[1]]['train_removed'] = True
|
||||
else:
|
||||
G[edge[0]][edge[1]]['train_removed'] = False
|
||||
return G
|
||||
|
||||
id_map = graph.look_up_dict
|
||||
#conversion = lambda n : int(n) # compatible with networkx >2.0
|
||||
#id_map = {conversion(k):int(v) for k,v in id_map.items()} # due to graphSAGE requirement
|
||||
def tranform_data_for_graphsage(self):
|
||||
''' OpenANE graph -> graphSAGE required format '''
|
||||
id_map = self.graph.look_up_dict
|
||||
G = self.graph.G
|
||||
feats = np.array([G.nodes[id]['attr'] for id in id_map.keys()])
|
||||
normalize = self.normalize
|
||||
if normalize and not feats is None:
|
||||
print("------------- row norm of node attributes ------------------", normalize)
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
train_inds = [id_map[n] for n in G.nodes() if not G.node[n]['val'] and not G.node[n]['test']]
|
||||
train_feats = feats[train_inds]
|
||||
scaler = StandardScaler()
|
||||
scaler.fit(train_feats)
|
||||
feats = scaler.transform(feats)
|
||||
#feats1 = nx.get_node_attributes(G,'test')
|
||||
#feats2 = nx.get_node_attributes(G,'val')
|
||||
walks = []
|
||||
walks = self.run_random_walks(num_walks=self.num_walks, walk_len=self.walk_len)
|
||||
class_map = 0 #to do... use sklearn to make class into binary form, no need for unsupervised...
|
||||
return G, feats, id_map, walks, class_map
|
||||
|
||||
feats = np.array([G.nodes[id]['attr'] for id in id_map.keys()])
|
||||
normalize = True #have decleared in __init__.py
|
||||
if normalize and not feats is None:
|
||||
print("-------------row norm of node attributes/features------------------")
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
train_inds = [id_map[n] for n in G.nodes() if not G.node[n]['val'] and not G.node[n]['test']]
|
||||
train_feats = feats[train_inds]
|
||||
scaler = StandardScaler()
|
||||
scaler.fit(train_feats)
|
||||
feats = scaler.transform(feats)
|
||||
#feats1 = nx.get_node_attributes(G,'test')
|
||||
#feats2 = nx.get_node_attributes(G,'val')
|
||||
def run_random_walks(self, num_walks=50, walk_len=5):
|
||||
''' generate random walks '''
|
||||
G = self.graph.G
|
||||
nodes = [n for n in G.nodes() if not G.node[n]["val"] and not G.node[n]["test"]]
|
||||
G = G.subgraph(nodes)
|
||||
pairs = []
|
||||
for count, node in enumerate(nodes):
|
||||
if G.degree(node) == 0:
|
||||
continue
|
||||
for i in range(num_walks):
|
||||
curr_node = node
|
||||
for j in range(walk_len):
|
||||
if len(list(G.neighbors(curr_node))) == 0: #isolated nodes! often appeared in real-world
|
||||
break
|
||||
next_node = random.choice(list(G.neighbors(curr_node))) #changed due to compatibility
|
||||
#next_node = random.choice(G.neighbors(curr_node))
|
||||
# self co-occurrences are useless
|
||||
if curr_node != node:
|
||||
pairs.append((node,curr_node))
|
||||
curr_node = next_node
|
||||
if count % 1000 == 0:
|
||||
print("Done walks for", count, "nodes")
|
||||
return pairs
|
||||
|
||||
walks = []
|
||||
walks = run_random_walks(G, num_walks=50, walk_len=5) #use the defualt parameter in graphSAGE
|
||||
|
||||
class_map = 0 #to do... use sklearn to make class into binary form, no need for unsupervised...
|
||||
return G, feats, id_map, walks, class_map
|
||||
|
||||
def graphsage_unsupervised_train(graph, graphsage_model = 'graphsage_mean'):
|
||||
train_data = tranform_data_for_graphsage(graph)
|
||||
#from unsupervised_train.py
|
||||
vectors = unsupervised_train.train(train_data, test_data=None, model = graphsage_model)
|
||||
return vectors
|
||||
|
||||
'''
|
||||
def save_embeddings(self, filename):
|
||||
fout = open(filename, 'w')
|
||||
node_num = len(self.vectors.keys())
|
||||
fout.write("{} {}\n".format(node_num, self.size))
|
||||
for node, vec in self.vectors.items():
|
||||
fout.write("{} {}\n".format(node,
|
||||
' '.join([str(x) for x in vec])))
|
||||
fout.close()
|
||||
'''
|
||||
def save_embeddings(self, filename):
|
||||
''' save embeddings to file '''
|
||||
fout = open(filename, 'w')
|
||||
node_num = len(self.vectors.keys())
|
||||
emb_dim = len(next(iter(self.vectors.values())))
|
||||
fout.write("{} {}\n".format(node_num, emb_dim))
|
||||
for node, vec in self.vectors.items():
|
||||
fout.write("{} {}\n".format(node,' '.join([str(x) for x in vec])))
|
||||
fout.close()
|
@ -85,11 +85,10 @@ def construct_placeholders():
|
||||
}
|
||||
return placeholders
|
||||
|
||||
|
||||
def train(train_data, test_data=None, model='graphsage_mean'):
|
||||
def train(train_data, test_data, model):
|
||||
print('---------- the graphsage model we used: ', model)
|
||||
print('---------- parameters we sued: epochs, dim_1+dim_2, samples_1, samples_2, dropout, weight_decay, learning_rate, batch_size, normalize',
|
||||
epochs, dim_1+dim_2, samples_1, samples_2, dropout, weight_decay, learning_rate, batch_size, normalize)
|
||||
print('---------- parameters we sued: epochs, dim_1+dim_2, samples_1, samples_2, dropout, weight_decay, learning_rate, batch_size',
|
||||
epochs, dim_1+dim_2, samples_1, samples_2, dropout, weight_decay, learning_rate, batch_size)
|
||||
G = train_data[0]
|
||||
features = train_data[1] #note: features are in order of graph.look_up_list, since id_map = {k: v for v, k in enumerate(graph.look_back_list)}
|
||||
id_map = train_data[2]
|
||||
@ -110,7 +109,7 @@ def train(train_data, test_data=None, model='graphsage_mean'):
|
||||
adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
|
||||
adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")
|
||||
|
||||
if model == 'graphsage_mean':
|
||||
if model == 'mean':
|
||||
# Create model
|
||||
sampler = UniformNeighborSampler(adj_info)
|
||||
layer_infos = [SAGEInfo("node", sampler, samples_1, dim_1),
|
||||
@ -141,7 +140,7 @@ def train(train_data, test_data=None, model='graphsage_mean'):
|
||||
concat=False,
|
||||
logging=True)
|
||||
|
||||
elif model == 'graphsage_seq': #LSTM as stated in paper? very slow anyway...
|
||||
elif model == 'seq': #LSTM as stated in paper? very slow anyway...
|
||||
sampler = UniformNeighborSampler(adj_info)
|
||||
layer_infos = [SAGEInfo("node", sampler, samples_1, dim_1),
|
||||
SAGEInfo("node", sampler, samples_2, dim_2)]
|
||||
@ -156,7 +155,7 @@ def train(train_data, test_data=None, model='graphsage_mean'):
|
||||
model_size=model_size,
|
||||
logging=True)
|
||||
|
||||
elif model == 'graphsage_maxpool':
|
||||
elif model == 'maxpool':
|
||||
sampler = UniformNeighborSampler(adj_info)
|
||||
layer_infos = [SAGEInfo("node", sampler, samples_1, dim_1),
|
||||
SAGEInfo("node", sampler, samples_2, dim_2)]
|
||||
@ -170,7 +169,7 @@ def train(train_data, test_data=None, model='graphsage_mean'):
|
||||
model_size=model_size,
|
||||
identity_dim = identity_dim,
|
||||
logging=True)
|
||||
elif model == 'graphsage_meanpool':
|
||||
elif model == 'meanpool':
|
||||
sampler = UniformNeighborSampler(adj_info)
|
||||
layer_infos = [SAGEInfo("node", sampler, samples_1, dim_1),
|
||||
SAGEInfo("node", sampler, samples_2, dim_2)]
|
||||
|
84
src/main.py
84
src/main.py
@ -5,7 +5,7 @@ STEP2: prepare data -->
|
||||
STEP3: learn node embeddings -->
|
||||
STEP4: downstream evaluations
|
||||
|
||||
python src/main.py --method abrw --save-emb False
|
||||
python src/main.py --method abrw
|
||||
|
||||
by Chengbin Hou 2018 <chengbin.hou10@foxmail.com>
|
||||
'''
|
||||
@ -65,8 +65,9 @@ def parse_args():
|
||||
parser.add_argument('--emb-file', default='emb/unnamed_node_embs.txt',
|
||||
help='node embeddings file; suggest: data_method_dim_embs.txt')
|
||||
#-------------------------------------------------method settings-----------------------------------------------------------
|
||||
parser.add_argument('--method', default='abrw', choices=['node2vec', 'deepwalk', 'line', 'gcn', 'grarep', 'tadw',
|
||||
'abrw', 'asne', 'aane', 'attrpure', 'attrcomb', 'graphsage'],
|
||||
parser.add_argument('--method', default='abrw', choices=['deepwalk', 'node2vec', 'line', 'grarep',
|
||||
'abrw', 'attrpure', 'attrcomb', 'tadw', 'aane',
|
||||
'sagemean','sagegcn', 'gcn', 'asne'],
|
||||
help='choices of Network Embedding methods')
|
||||
parser.add_argument('--ABRW-topk', default=30, type=int,
|
||||
help='select the most attr similar top k nodes of a node; ranging [0, # of nodes]')
|
||||
@ -134,8 +135,8 @@ def main(args):
|
||||
elif args.graph_format == 'edgelist':
|
||||
g.read_edgelist(path=args.graph_file, weighted=args.weighted, directed=args.directed)
|
||||
#load node attribute info------
|
||||
is_ane = (args.method == 'abrw' or args.method == 'tadw' or args.method == 'gcn' or args.method == 'graphsage' or
|
||||
args.method == 'attrpure' or args.method == 'attrcomb' or args.method == 'asne' or args.method == 'aane')
|
||||
is_ane = (args.method == 'abrw' or args.method == 'tadw' or args.method == 'gcn' or args.method == 'sagemean' or args.method == 'sagegcn' or
|
||||
args.method == 'attrpure' or args.method == 'attrcomb' or args.method == 'asne' or args.method == 'aane')
|
||||
if is_ane:
|
||||
assert args.attribute_file != ''
|
||||
g.read_node_attr(args.attribute_file)
|
||||
@ -188,6 +189,12 @@ def main(args):
|
||||
elif args.method == 'line': #if auto_save, use label to justifiy the best embeddings by looking at micro / macro-F1 score
|
||||
model = line.LINE(graph=g, epoch = args.epochs, rep_size=args.dim, order=args.LINE_order, batch_size=args.batch_size, negative_ratio=args.LINE_negative_ratio,
|
||||
label_file=args.label_file, clf_ratio=args.label_reserved, auto_save=True, best='micro')
|
||||
|
||||
elif args.method == 'sagemean': #other choices: graphsage_seq, graphsage_maxpool, graphsage_meanpool, n2v
|
||||
model = graphsageAPI.graphSAGE(graph=g, sage_model='mean', is_supervised=False)
|
||||
elif args.method == 'sagegcn': #parameters for graphsage models are in 'graphsage' -> '__init__.py'
|
||||
model = graphsageAPI.graphSAGE(graph=g, sage_model='gcn', is_supervised=False)
|
||||
|
||||
elif args.method == 'asne':
|
||||
if args.task == 'nc':
|
||||
model = asne.ASNE(graph=g, dim=args.dim, alpha=args.ASNE_lamb, epoch=args.epochs, learning_rate=args.learning_rate, batch_size=args.batch_size,
|
||||
@ -195,53 +202,46 @@ def main(args):
|
||||
else:
|
||||
model = asne.ASNE(graph=g, dim=args.dim, alpha=args.ASNE_lamb, epoch=args.epochs, learning_rate=args.learning_rate, batch_size=args.batch_size,
|
||||
X_test=test_node_pairs, Y_test=test_edge_labels, task=args.task, nc_ratio=args.label_reserved, lp_ratio=args.link_reserved, label_file=args.label_file)
|
||||
elif args.method == 'graphsage': #we follow the default parameters, see __inti__.py in graphsage file
|
||||
model = graphsageAPI.graphsage_unsupervised_train(graph=g, graphsage_model = 'graphsage_mean')
|
||||
elif args.method == 'gcn':
|
||||
model = graphsageAPI.graphsage_unsupervised_train(graph=g, graphsage_model = 'gcn') #graphsage-gcn
|
||||
else:
|
||||
print('no method was found...')
|
||||
print('method not found...')
|
||||
exit(0)
|
||||
'''
|
||||
elif args.method == 'gcn': #OR use graphsage-gcn as in graphsage method...
|
||||
assert args.label_file != '' #must have node label
|
||||
assert args.feature_file != '' #different from previous ANE methods
|
||||
g.read_node_label(args.label_file) #gcn is an end-to-end supervised ANE methoed
|
||||
model = gcnAPI.GCN(graph=g, dropout=args.dropout,
|
||||
weight_decay=args.weight_decay, hidden1=args.hidden,
|
||||
epochs=args.epochs, clf_ratio=args.label_reserved)
|
||||
#gcn does not have model.save_embeddings() func
|
||||
'''
|
||||
t2 = time.time()
|
||||
print(f'STEP3: end learning embeddings; time cost: {(t2-t1):.2f}s')
|
||||
|
||||
if args.save_emb:
|
||||
model.save_embeddings(args.emb_file + time.strftime(' %Y%m%d-%H%M%S', time.localtime()))
|
||||
print(f'Save node embeddings in file: {args.emb_file}')
|
||||
t2 = time.time()
|
||||
print(f'STEP3: end learning embeddings; time cost: {(t2-t1):.2f}s')
|
||||
|
||||
'''
|
||||
#to do.... semi-supervised methods: gcn, graphsage, etc...
|
||||
if args.method == 'gcn': #semi-supervised gcn
|
||||
assert args.label_file != '' #must have node label
|
||||
assert args.feature_file != '' #different from previous ANE methods
|
||||
g.read_node_label(args.label_file) #gcn is an end-to-end supervised ANE methoed
|
||||
model = gcnAPI.GCN(graph=g, dropout=args.dropout, weight_decay=args.weight_decay, hidden1=args.hidden, epochs=args.epochs, clf_ratio=args.label_reserved)
|
||||
print('semi-supervsied method, no embs, exit the program...') #semi-supervised gcn do not produce embs
|
||||
exit(0)
|
||||
'''
|
||||
|
||||
|
||||
#---------------------------------------STEP4: downstream task-----------------------------------------------
|
||||
print('\nSTEP4: start evaluating ......: ')
|
||||
t1 = time.time()
|
||||
if args.method != 'semi_supervised_gcn': #except semi-supervised methods, we will get emb first, and then eval emb
|
||||
vectors = 0
|
||||
if args.method == 'graphsage' or args.method == 'gcn': #to do... run without this 'if'
|
||||
vectors = model
|
||||
else:
|
||||
vectors = model.vectors #for other methods....
|
||||
del model, g
|
||||
#------lp task
|
||||
if args.task == 'lp' or args.task == 'lp_and_nc':
|
||||
#X_test_lp, Y_test_lp = read_edge_label(args.label_file) #if you want to load your own lp testing data
|
||||
print(f'Link Prediction task; the percentage of positive links for testing: {(args.link_remove*100):.2f}%'
|
||||
+ ' (by default, also generate equal negative links for testing)')
|
||||
clf = lpClassifier(vectors=vectors) #similarity/distance metric as clf; basically, lp is a binary clf probelm
|
||||
clf.evaluate(test_node_pairs, test_edge_labels)
|
||||
#------nc task
|
||||
if args.task == 'nc' or args.task == 'lp_and_nc':
|
||||
X, Y = read_node_label(args.label_file)
|
||||
print(f'Node Classification task; the percentage of labels for testing: {((1-args.label_reserved)*100):.2f}%')
|
||||
clf = ncClassifier(vectors=vectors, clf=LogisticRegression()) #use Logistic Regression as clf; we may choose SVM or more advanced ones
|
||||
clf.split_train_evaluate(X, Y, args.label_reserved)
|
||||
vectors = model.vectors
|
||||
del model, g
|
||||
#------lp task
|
||||
if args.task == 'lp' or args.task == 'lp_and_nc':
|
||||
#X_test_lp, Y_test_lp = read_edge_label(args.label_file) #if you want to load your own lp testing data
|
||||
print(f'Link Prediction task; the percentage of positive links for testing: {(args.link_remove*100):.2f}%'
|
||||
+ ' (by default, also generate equal negative links for testing)')
|
||||
clf = lpClassifier(vectors=vectors) #similarity/distance metric as clf; basically, lp is a binary clf probelm
|
||||
clf.evaluate(test_node_pairs, test_edge_labels)
|
||||
#------nc task
|
||||
if args.task == 'nc' or args.task == 'lp_and_nc':
|
||||
X, Y = read_node_label(args.label_file)
|
||||
print(f'Node Classification task; the percentage of labels for testing: {((1-args.label_reserved)*100):.2f}%')
|
||||
clf = ncClassifier(vectors=vectors, clf=LogisticRegression()) #use Logistic Regression as clf; we may choose SVM or more advanced ones
|
||||
clf.split_train_evaluate(X, Y, args.label_reserved)
|
||||
t2 = time.time()
|
||||
print(f'STEP4: end evaluating; time cost: {(t2-t1):.2f}s')
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user