diff --git a/adnc/data/tasks/cnn_rc.py b/adnc/data/tasks/cnn_rc.py
index 52e3fce..a992b0e 100644
--- a/adnc/data/tasks/cnn_rc.py
+++ b/adnc/data/tasks/cnn_rc.py
@@ -12,9 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+import os
 import operator
 import pathlib
 import tarfile
+import pickle
 from collections import Counter, OrderedDict
 from urllib.request import Request, urlopen
 
@@ -249,6 +251,29 @@ class ReadingComprehension():
 
         return x_word, prediction_decode
 
+    def save_dictionary(self, dir):
+        """Pickle ``self.word_idx_dict`` to ``<dir>/word_idx_dict.pkl``.
+
+        Args:
+            dir: directory in which the pickle file is written.
+        """
+        with open(os.path.join(dir, 'word_idx_dict.pkl'), 'wb') as outfile:
+            pickle.dump(self.word_idx_dict, outfile)
+
+    def load_dictionary(self, dir):
+        """Load ``self.word_idx_dict`` from ``<dir>/word_idx_dict.pkl``.
+
+        Also rebuilds ``self.idx_word_dict`` as the inverse mapping.
+
+        Args:
+            dir: directory containing the pickle file.
+        """
+        # NOTE(review): pickle.load can execute arbitrary code; only load
+        # dictionary files that come from a trusted source.
+        with open(os.path.join(dir, 'word_idx_dict.pkl'), 'rb') as infile:
+            self.word_idx_dict = pickle.load(infile)
+        self.idx_word_dict = {v: k for k, v in self.word_idx_dict.items()}
+
     @property
     def vocabulary_size(self):
         return self.word_idx_dict.__len__()