From 7e868ea8899dcfe1c1db2e6ae78737d2bbdb6231 Mon Sep 17 00:00:00 2001 From: Chengbin HOU <31309465+houchengbin@users.noreply.github.com> Date: Mon, 15 Apr 2019 11:44:09 +0100 Subject: [PATCH] update ASNE method (IEEE TKDE2018) --- src/libnrl/asne.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/libnrl/asne.py b/src/libnrl/asne.py index c83f650..2b1211d 100644 --- a/src/libnrl/asne.py +++ b/src/libnrl/asne.py @@ -68,20 +68,22 @@ class ASNE(BaseEstimator, TransformerMixin): self.attr_embed = tf.matmul(self.train_data_attr, self.weights['attr_embeddings']) # batch_size * attr_dim self.embed_layer = tf.concat([self.id_embed, self.alpha * self.attr_embed], 1) # batch_size * (id_dim + attr_dim) #an error due to old tf! - ''' + ## can add hidden_layers component here!---------------------------------- #0) no hidden layer #1) 128 - #2) 256+128 ##--------paper stated it used two hidden layers with softsign + #2) 256+128 #3) 512+256+128 + # Note: according to the Fig 5 in paper https://ieeexplore.ieee.org/abstract/document/8326519 + # here we follow 2) i.e. 256 softsign + 128 softsign len_h1_in = self.id_embedding_size + self.attr_embedding_size - len_h1_out = 256 #or self.id_embedding_size + self.attr_embedding_size # if only add h1 + len_h1_out = 256 # 256 softsign len_h2_in = len_h1_out - len_h2_out = self.id_embedding_size + self.attr_embedding_size + len_h2_out = self.id_embedding_size + self.attr_embedding_size # 128 softsign i.e. dim of embeddings self.h1 = add_layer(inputs=self.embed_layer, in_size=len_h1_in, out_size=len_h1_out, activation_function=tf.nn.softsign) self.h2 = add_layer(inputs=self.h1, in_size=len_h2_in, out_size=len_h2_out, activation_function=tf.nn.softsign) ## ------------------------------------------------------------------------- - ''' + # Compute the loss, using a sample of the negative labels each time. self.loss = tf.reduce_mean(tf.nn.sampled_softmax_loss(weights=self.weights['out_embeddings'], biases=self.weights['biases'], # if one needs to change layers