mirror of
https://github.com/JoergFranke/ADNC.git
synced 2024-11-17 13:58:03 +08:00
fix pylint errors
This commit is contained in:
parent 8054cc6dda
commit 098b44ac51
@@ -21,10 +21,20 @@ from adnc.analysis.plot_functionality import PlotFunctionality
 from adnc.analysis.prepare_variables import Bucket
 from adnc.model.utils import softmax
 
+"""
+
+"""
+
 class Analyser():
     def __init__(self, data_set, record_dir, save_variables=False, save_fig=False):
+        """
+
+        Args:
+            data_set:
+            record_dir:
+            save_variables:
+            save_fig:
+        """
         self.data_set = data_set
         self.record_dir = record_dir
         self.save_variables = save_variables
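Note: the stub docstrings above silence pylint's missing-docstring warning; the fields are left empty in the commit itself. For reference, a filled-in Google-style docstring of the kind pylint expects looks like this (toy function, not ADNC code):

def scale(values, factor=2.0):
    """Scale a list of numbers.

    Args:
        values: list of floats to scale.
        factor: multiplier applied to each element.
    """
    return [v * factor for v in values]

print(scale([1.0, 2.0]))  # [2.0, 4.0]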
@@ -72,7 +82,8 @@ class Analyser():
 
         return self.estimate_memory_usage(variables)
 
-    def plot_analysis(self, variables, plot_dir, name='variables'):
+    @staticmethod
+    def plot_analysis(variables, plot_dir, name='variables'):
 
         buck = Bucket(variables)
         plotter = PlotFunctionality(bucket=buck)
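Note: most hunks in this commit apply the same fix: a method that never reads self becomes a @staticmethod, which silences pylint's no-self-use (R0201) warning. A minimal sketch (toy class, not the ADNC code) showing that existing call sites keep working, since a static method is still reachable through an instance:

class Analyser:
    @staticmethod
    def plot_analysis(variables, plot_dir, name='variables'):
        # no reference to self, so @staticmethod fits
        return '{}/{}: {} variables'.format(plot_dir, name, len(variables))

print(Analyser().plot_analysis([1, 2], '/tmp'))  # instance call still works
print(Analyser.plot_analysis([1, 2], '/tmp'))    # class call works too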
@@ -80,7 +91,8 @@ class Analyser():
         plotter.plot_basic_functionality(batch=0, plot_dir=plot_dir, name=name, show=True)
         plotter.plot_advanced_functionality(batch=0, plot_dir=plot_dir, name=name, show=True)
 
-    def estimate_memory_usage(self, variables):
+    @staticmethod
+    def estimate_memory_usage(variables):
 
         analyse_values, prediction, decoded_predictions, data_sample, weights_dict = variables
         data, target, mask, x_word = data_sample
@@ -71,7 +71,8 @@ class bAbI():
         self.word_dict = word_dict
         self.re_word_dict = re_word_dict
 
-    def download_data(self, data_dir):
+    @staticmethod
+    def download_data(data_dir):
 
         folder_name = 'tasks_1-20_v1-2'
 
@@ -106,7 +106,8 @@ class ReadingComprehension():
         self.re_entity_dict = {v: k for k, v in self.entity_dict.items()}
         self.idx_word_dict = {v: k for k, v in self.word_idx_dict.items()}
 
-    def download_data(self, data_dir):
+    @staticmethod
+    def download_data(data_dir):
 
         folder_name = 'cnn'
 
@@ -123,7 +123,8 @@ class CopyTask():
 
         return batch
 
-    def decode_output(self, sample, prediction):
+    @staticmethod
+    def decode_output(sample, prediction):
         if prediction.shape.__len__() == 3:
             prediction_decode_list = []
             target_decode_list = []
@@ -47,7 +47,7 @@ class BatchGenerator():
         self.order = self.data_set.rng.permutation(self.order)
 
     def __iter__(self):
-        return self
+        return next(self)
 
     def __next__(self):
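Note: BatchGenerator implements __next__, so the training loop can pull batches with next() directly. A toy stand-in (not the ADNC class) showing that driving pattern:

class CountingBatches:
    def __init__(self, limit):
        self.limit = limit
        self.count = 0

    def __next__(self):
        # hand out numbered dummy batches until the limit is reached
        self.count += 1
        if self.count > self.limit:
            raise StopIteration
        return {'batch_no': self.count}

gen = CountingBatches(2)
print(next(gen))  # {'batch_no': 1}
print(next(gen))  # {'batch_no': 2}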
@@ -163,14 +163,7 @@ class MANN():
     @property
     def parameter_amount(self):
         var_list = tf.trainable_variables()
-        parameters = 0
-        for variable in var_list:
-            shape = variable.get_shape()
-            variable_parametes = 1
-            for dim in shape:
-                variable_parametes *= dim.value
-            parameters += variable_parametes
-        return parameters
+        return self.count_parameter_amount(var_list)
 
 
     def get_loss(self, prediction):
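Note: the counting loop moves into a count_parameter_amount helper that this hunk only calls; the helper's body is not part of this diff. A sketch of what it plausibly looks like, inferred from the removed loop (TF1 Dimension objects assumed; the real ADNC helper may differ):

def count_parameter_amount(var_list):
    # multiply out each variable's shape, then sum over all variables
    parameters = 0
    for variable in var_list:
        variable_parameters = 1
        for dim in variable.get_shape():  # TF1: yields Dimension objects
            variable_parameters *= dim.value
        parameters += variable_parameters
    return parameters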
@@ -39,6 +39,10 @@ class BaseMemoryUnitCell():
         self.reuse = reuse
         self.name = name
 
+        self.const_memory_ones = None  # will be defined with use of batch size in call method
+        self.const_batch_memory_range = None  # will be defined with use of batch size in call method
+        self.const_link_matrix_inv_eye = None  # will be defined with use of batch size in call method
+
     @property
     @abstractmethod
     def state_size(self):
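Note: predeclaring the constants as None in __init__ keeps pylint's attribute-defined-outside-init (W0201) check quiet when the real values are assigned later, once the batch size is known. A minimal sketch of the pattern:

class Cell:
    def __init__(self):
        # declared up front; the real value needs the batch size, known only later
        self.const_memory_ones = None

    def build(self, batch_size):
        # assigning here no longer counts as defining the attribute outside __init__
        self.const_memory_ones = [1.0] * batch_size

cell = Cell()
cell.build(4)
print(cell.const_memory_ones)  # [1.0, 1.0, 1.0, 1.0]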
@@ -233,7 +233,8 @@ class MWDNCMemoryUnitCell(BaseMemoryUnitCell):
 
         return write_weighting
 
-    def _update_memory(self, pre_memory, write_weighting, write_vector, erase_vector):
+    @staticmethod
+    def _update_memory(pre_memory, write_weighting, write_vector, erase_vector):
 
         write_w = tf.expand_dims(write_weighting, 3)
         erase_vector = tf.expand_dims(erase_vector, 2)
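Note: _update_memory implements the DNC erase/write rule M_t = M_{t-1} * (1 - w e^T) + w v^T, and the expand_dims calls set up those outer products across the batch and write-head axes. A single-head NumPy sketch of the same rule (the MWDNC cell additionally carries batch and head dimensions):

import numpy as np

def update_memory(pre_memory, write_weighting, write_vector, erase_vector):
    w = write_weighting[:, None]       # (memory slots, 1)
    erase = w * erase_vector[None, :]  # outer product w e^T
    write = w * write_vector[None, :]  # outer product w v^T
    return pre_memory * (1.0 - erase) + write

memory = np.zeros((4, 3))           # 4 slots, word width 3
w = np.array([1.0, 0.0, 0.0, 0.0])  # write entirely into slot 0
print(update_memory(memory, w, np.ones(3), np.ones(3))[0])  # [1. 1. 1.]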
@@ -24,14 +24,15 @@ import yaml
 
 
 class color_code:
-    bold = '\033[1m'
-    underline = '\033[4m'
-    blue = '\033[94m'
-    darkcyan = '\033[36m'
-    green = '\033[92m'
-    red = '\033[91m'
-    yellow = '\033[93m'
-    end = '\033[0m'
+    def __init__(self):
+        self.bold = '\033[1m'
+        self.underline = '\033[4m'
+        self.blue = '\033[94m'
+        self.darkcyan = '\033[36m'
+        self.green = '\033[92m'
+        self.red = '\033[91m'
+        self.yellow = '\033[93m'
+        self.end = '\033[0m'
 
 
 class Supporter():
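Note: the ANSI codes move from class attributes to instance attributes, so call sites now need a color_code() object instead of reading the codes off the class. Usage with a trimmed stand-in of the class above:

class color_code:
    def __init__(self):
        self.red = '\033[91m'
        self.end = '\033[0m'

cc = color_code()
print(cc.red + 'loss diverged' + cc.end)  # renders red on ANSI terminals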
@@ -41,6 +41,11 @@ class HolisticMultiRNNCell(RNNCell):
                 "state_is_tuple is not set. State sizes are: %s"
                 % str([c.state_size for c in self._cells]))
 
+
+    def compute_output_shape(self, input_shape):
+        sizes = [cell.output_size for cell in self._cells]
+        return sum(sizes)
+
     @property
     def state_size(self):
         if self._state_is_tuple:
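Note: the added compute_output_shape reports the stacked cell's total output width as the sum of the sub-cells' output sizes, which suggests the holistic cell concatenates every layer's output. A toy illustration with stand-in cells:

class FakeCell:
    def __init__(self, output_size):
        self.output_size = output_size

class Stack:
    def __init__(self, cells):
        self._cells = cells

    def compute_output_shape(self, input_shape):
        # input_shape is unused: the width is fixed by the sub-cells
        return sum(cell.output_size for cell in self._cells)

print(Stack([FakeCell(64), FakeCell(32)]).compute_output_shape((None, 10)))  # 96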
@@ -62,7 +67,7 @@ class HolisticMultiRNNCell(RNNCell):
             # presumably does not contain TensorArrays or anything else fancy
             return super(HolisticMultiRNNCell, self).zero_state(batch_size, dtype)
 
-    def call(self, inputs, state):
+    def call(self, inputs, state, scope=None):
         """Run this multi-layer cell on inputs, starting from state."""
         cur_state_pos = 0
         cur_inp = inputs
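Note: adding scope=None brings the override in line with the signature of the base class call, which is what pylint's arguments-differ (W0221) check compares against. A minimal sketch of the rule, with toy classes:

class Base:
    def call(self, inputs, state, scope=None):
        raise NotImplementedError

class Child(Base):
    # parameter list matches Base.call, so W0221 stays quiet
    def call(self, inputs, state, scope=None):
        del scope  # accepted for interface compatibility, unused here
        return inputs, state

print(Child().call('x', 's'))  # ('x', 's')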
@@ -42,7 +42,8 @@ class WordEmbedding():
         embed = tf.nn.embedding_lookup(self.embeddings, word_idx, name='embedding_lookup')
         return embed
 
-    def initialize_random(self, vocabulary_size, embedding_size, dtype):
+    @staticmethod
+    def initialize_random(vocabulary_size, embedding_size, dtype):
         return tf.random_uniform([vocabulary_size, embedding_size], -1.0, 1.0, dtype=dtype)
 
     def initialize_with_glove(self, word_idx_dict, embedding_size, tmp_dir, dtype):
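Note: initialize_random draws the embedding matrix uniformly from [-1, 1). A NumPy sketch of the same initialization plus a lookup, mirroring what tf.random_uniform and tf.nn.embedding_lookup do here (toy sizes assumed):

import numpy as np

def initialize_random(vocabulary_size, embedding_size, dtype=np.float32):
    rng = np.random.default_rng(42)
    return rng.uniform(-1.0, 1.0, (vocabulary_size, embedding_size)).astype(dtype)

embeddings = initialize_random(100, 16)  # hypothetical vocab size and width
word_idx = np.array([3, 17, 3])
embed = embeddings[word_idx]             # row gather, like embedding_lookup
print(embed.shape)                       # (3, 16)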