fix pylint errors

This commit is contained in:
joergfranke 2018-07-09 15:35:13 +02:00
parent 8054cc6dda
commit 098b44ac51
11 changed files with 45 additions and 25 deletions

View File

@ -21,10 +21,20 @@ from adnc.analysis.plot_functionality import PlotFunctionality
from adnc.analysis.prepare_variables import Bucket
from adnc.model.utils import softmax
"""
"""
class Analyser():
def __init__(self, data_set, record_dir, save_variables=False, save_fig=False):
"""
Args:
data_set:
record_dir:
save_variables:
save_fig:
"""
self.data_set = data_set
self.record_dir = record_dir
self.save_variables = save_variables
@ -72,7 +82,8 @@ class Analyser():
return self.estimate_memory_usage(variables)
def plot_analysis(self, variables, plot_dir, name='variables'):
@staticmethod
def plot_analysis(variables, plot_dir, name='variables'):
buck = Bucket(variables)
plotter = PlotFunctionality(bucket=buck)
@ -80,7 +91,8 @@ class Analyser():
plotter.plot_basic_functionality(batch=0, plot_dir=plot_dir, name=name, show=True)
plotter.plot_advanced_functionality(batch=0, plot_dir=plot_dir, name=name, show=True)
def estimate_memory_usage(self, variables):
@staticmethod
def estimate_memory_usage(variables):
analyse_values, prediction, decoded_predictions, data_sample, weights_dict = variables
data, target, mask, x_word = data_sample

View File

@ -71,7 +71,8 @@ class bAbI():
self.word_dict = word_dict
self.re_word_dict = re_word_dict
def download_data(self, data_dir):
@staticmethod
def download_data(data_dir):
folder_name = 'tasks_1-20_v1-2'

View File

@ -106,7 +106,8 @@ class ReadingComprehension():
self.re_entity_dict = {v: k for k, v in self.entity_dict.items()}
self.idx_word_dict = {v: k for k, v in self.word_idx_dict.items()}
def download_data(self, data_dir):
@staticmethod
def download_data(data_dir):
folder_name = 'cnn'

View File

@ -123,7 +123,8 @@ class CopyTask():
return batch
def decode_output(self, sample, prediction):
@staticmethod
def decode_output(sample, prediction):
if prediction.shape.__len__() == 3:
prediction_decode_list = []
target_decode_list = []

View File

@ -47,7 +47,7 @@ class BatchGenerator():
self.order = self.data_set.rng.permutation(self.order)
def __iter__(self):
return self
return next(self)
def __next__(self):

View File

@ -163,14 +163,7 @@ class MANN():
@property
def parameter_amount(self):
var_list = tf.trainable_variables()
parameters = 0
for variable in var_list:
shape = variable.get_shape()
variable_parametes = 1
for dim in shape:
variable_parametes *= dim.value
parameters += variable_parametes
return parameters
return self.count_parameter_amount(var_list)
def get_loss(self, prediction):

View File

@ -39,6 +39,10 @@ class BaseMemoryUnitCell():
self.reuse = reuse
self.name = name
self.const_memory_ones = None # will be defined with use of batch size in call method
self.const_batch_memory_range = None # will be defined with use of batch size in call method
self.const_link_matrix_inv_eye = None # will be defined with use of batch size in call method
@property
@abstractmethod
def state_size(self):

View File

@ -233,7 +233,8 @@ class MWDNCMemoryUnitCell(BaseMemoryUnitCell):
return write_weighting
def _update_memory(self, pre_memory, write_weighting, write_vector, erase_vector):
@staticmethod
def _update_memory(pre_memory, write_weighting, write_vector, erase_vector):
write_w = tf.expand_dims(write_weighting, 3)
erase_vector = tf.expand_dims(erase_vector, 2)

View File

@ -24,14 +24,15 @@ import yaml
class color_code:
bold = '\033[1m'
underline = '\033[4m'
blue = '\033[94m'
darkcyan = '\033[36m'
green = '\033[92m'
red = '\033[91m'
yellow = '\033[93m'
end = '\033[0m'
def __init__(self):
self.bold = '\033[1m'
self.underline = '\033[4m'
self.blue = '\033[94m'
self.darkcyan = '\033[36m'
self.green = '\033[92m'
self.red = '\033[91m'
self.yellow = '\033[93m'
self.end = '\033[0m'
class Supporter():

View File

@ -41,6 +41,11 @@ class HolisticMultiRNNCell(RNNCell):
"state_is_tuple is not set. State sizes are: %s"
% str([c.state_size for c in self._cells]))
def compute_output_shape(self, input_shape):
sizes = [cell.output_size for cell in self._cells]
return sum(sizes)
@property
def state_size(self):
if self._state_is_tuple:
@ -62,7 +67,7 @@ class HolisticMultiRNNCell(RNNCell):
# presumably does not contain TensorArrays or anything else fancy
return super(HolisticMultiRNNCell, self).zero_state(batch_size, dtype)
def call(self, inputs, state):
def call(self, inputs, state, scope=None):
"""Run this multi-layer cell on inputs, starting from state."""
cur_state_pos = 0
cur_inp = inputs

View File

@ -42,7 +42,8 @@ class WordEmbedding():
embed = tf.nn.embedding_lookup(self.embeddings, word_idx, name='embedding_lookup')
return embed
def initialize_random(self, vocabulary_size, embedding_size, dtype):
@staticmethod
def initialize_random(vocabulary_size, embedding_size, dtype):
return tf.random_uniform([vocabulary_size, embedding_size], -1.0, 1.0, dtype=dtype)
def initialize_with_glove(self, word_idx_dict, embedding_size, tmp_dir, dtype):