# AttentivePathRanking/main/data/MIDFreebase15kReader.py

import os
import glob
from gensim.models import KeyedVectors
import pickle
import matplotlib.pyplot as plt
import numpy as np
import json
import wikidata.client

# We use all triples in the train and test sets to write one complete set of relation
# instances. The original split of the data will not be used; a new split will be
# constructed later.
class MIDFreebase15kReader:
"""
This class helps read raw FB15k-237 data and rewrite the data in the format necessary for the experiments.
Data files involved:
- input: ``/FB15k-237/train.txt``, ``/FB15k-237/test.txt``, ``/FB15k-237/valid.txt``, ``/FB15k-237/fb2w.nt``,
``/type_information/entity2type.txt``, ``/type_information/relation_specific.txt``
- output: ``edges.txt``, ``domains.tsv``, ``ranges.tsv``, ``synonym2vec.pkl``, ``entity2types.json``,
``mid2name.pkl``
.. note::
We extract all triples in the original train and test set. The split of the original data is suitable for
testing embedding methods and will not be used.
.. note::
The entities in the raw data are MIDs (Machine Identifiers), which are numbers.
"""

    def __init__(self, dir, filter=False, word2vec_filename=""):
        """
        Initialize a reader object.

        :param dir: the root of the data folder, where the ``/FB15k-237`` folder and the ``/type_information``
            folder are found and where all output files will be written.
        :param filter: if set to True, entities with no corresponding word embeddings will be removed.
        :param word2vec_filename: the word embeddings for the entities.
        """
        self.dir = dir
        if not os.path.exists(dir):
            raise Exception(dir + " does not exist")
        self.train_instances = []
        self.relations = set()
        self.mids = set()
        self.relation_domain = {}
        self.relation_range = {}
        self.filter = filter
        self.word2vec_filename = word2vec_filename
        self.synonym2vec = {}
        self.client = wikidata.client.Client()

    def read_data(self):
        """
        This function processes the raw data. Specifically, it does the following:

        - It reads all relation instances in ``/FB15k-237`` and the type information in ``/type_information``
          for entities and relations.
        - It removes relations and entities without type information.
        - It retrieves word embeddings for all entities and saves the embeddings in ``synonym2vec.pkl``.
        - It prints out relations with more than 1000 instances to help us select relations to test with.
        - It finds the 7 most frequently occurring types for each entity and saves them in ``entity2types.json``.
        """
        # 1. collect all relevant files
        files = glob.glob(os.path.join(self.dir, "FB15k-237/*.txt"))
        for file in files:
            if "train" in file:
                train_filename = file
            elif "test" in file:
                test_filename = file
            elif "valid" in file:
                dev_filename = file
        files = glob.glob(os.path.join(self.dir, "type_information/*.txt"))
        for file in files:
            if "entity2type" in file:
                entity2type_filename = file
            elif "relation_specific" in file:
                relation_domain_range_filename = file
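        # Assumed tab-separated input formats (the examples are illustrative, not taken
        # from the data):
        #   entity2type.txt:        /m/0abc<TAB>/people/person<TAB>/award/award_winner ...
        #   relation_specific.txt:  /people/person/nationality<TAB>/people/person<TAB>/location/country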
        # 2. read entity type and relation domain and range information
        mid2types = {}
        with open(entity2type_filename) as fh:
            for line in fh:
                line = line.strip()
                if len(line) == 0:
                    continue
                contents = line.split("\t")
                mid = contents[0]
                # important: "/" in MIDs and relations could be confused with "/" as a directory
                # separator, so we replace it with "|" (e.g. "/m/0abc" -> "|m|0abc")
                mid = mid.replace("/", "|")
                types = contents[1:]
                mid2types[mid] = types
        relation_domain = {}
        relation_range = {}
        with open(relation_domain_range_filename) as fh:
            for line in fh:
                line = line.strip()
                if len(line) == 0:
                    continue
                relation, domain, range_ = line.split("\t")
                relation = relation.replace("/", "|")
                domain = domain.replace("/", "|")
                range_ = range_.replace("/", "|")
                relation_domain[relation] = domain
                relation_range[relation] = range_
        # 3. get all mids and relations in the data set
        mids = set()
        relations = set()
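        # each line of train/test/valid is "head<TAB>relation<TAB>tail",
        # e.g. (an assumed example): /m/0abc<TAB>/people/person/nationality<TAB>/m/0def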
        for filename in [train_filename, test_filename, dev_filename]:
            with open(filename, "r") as fh:
                for line in fh:
                    line = line.strip()
                    if len(line) == 0:
                        continue
                    source, relation, target = line.split("\t")
                    mids.add(source.replace("/", "|"))
                    mids.add(target.replace("/", "|"))
                    relations.add(relation.replace("/", "|"))
        # 4. remove mids with no type information and relations with no domain and range information
        # Note: entities without type information are filtered out here, as was also done for the
        # WN18RR dataset.
        filtered_mids = set()
        for mid in mids:
            if mid in mid2types:
                filtered_mids.add(mid)
        print("mids before and after filtering by type information:", len(mids), len(filtered_mids))
        filtered_relations = set()
        for relation in relations:
            if relation in relation_domain:
                filtered_relations.add(relation)
        print("relations before and after filtering by type information:", len(relations), len(filtered_relations))
        # 5. load word2vec and remove mids with no word2vec mapping; the mapping will be used for
        # context-aware models
        mids = filtered_mids
        filtered_mids = set()
        if self.filter:
            if os.path.exists(os.path.join(self.dir, "synonym2vec.pkl")):
                print("synonym2vec pickle file found:", os.path.join(self.dir, "synonym2vec.pkl"))
                with open(os.path.join(self.dir, "synonym2vec.pkl"), "rb") as fh:
                    self.synonym2vec = pickle.load(fh)
                # keep only the mids covered by the cached embeddings
                filtered_mids = {mid for mid in mids if mid in self.synonym2vec}
            else:
                if os.path.exists(self.word2vec_filename):
                    # Important: only the native word2vec file needs the binary flag to be True
                    word2vec_model = KeyedVectors.load_word2vec_format(self.word2vec_filename, binary=True)
                else:
                    raise Exception("word2vec file not found")
                print("word2vec loaded")
                for mid in mids:
                    try:
                        self.synonym2vec[mid] = word2vec_model.get_vector(mid.replace("|", "/"))
                        filtered_mids.add(mid)
                    except KeyError:
                        pass
                print(len(mids) - len(filtered_mids), "MIDs have no matching word2vec entry;",
                      len(filtered_mids), "MIDs do.")
                print("Saving synonym2vec to pickle file:", os.path.join(self.dir, "synonym2vec.pkl"))
                with open(os.path.join(self.dir, "synonym2vec.pkl"), "wb") as fh:
                    pickle.dump(self.synonym2vec, fh)
        else:
            # no embedding filtering requested: keep all type-filtered mids
            filtered_mids = mids
        self.relations = filtered_relations
        self.mids = filtered_mids
        self.relation_domain = relation_domain
        self.relation_range = relation_range
        # 6. read data and keep track of the number of instances of each relation
        relation_to_instance_count = {}
        for filename in [train_filename, test_filename, dev_filename]:
            with open(filename, "r") as fh:
                for line in fh:
                    line = line.strip()
                    if len(line) == 0:
                        continue
                    source, relation, target = line.split("\t")
                    source = source.replace("/", "|")
                    relation = relation.replace("/", "|")
                    target = target.replace("/", "|")
                    if source in filtered_mids and target in filtered_mids and relation in filtered_relations:
                        self.train_instances.append((source, relation, target))
                        if relation not in relation_to_instance_count:
                            relation_to_instance_count[relation] = 0
                        relation_to_instance_count[relation] += 1
        print("There are", len(self.train_instances), "triples in the data")
        print("There are", len(relation_to_instance_count), "relations")
        print("There are", len(self.mids), "entities")
        print("Relations with more than 1000 instances:")
        for relation in relation_to_instance_count:
            if relation_to_instance_count[relation] > 1000:
                print("\"" + relation + "\",", end=' ')
        print()
        # 7. write entity2types.json
        # (1) count how often each type occurs across all entities
        type2count = {}
        for mid in mid2types:
            types = mid2types[mid]
            for entity_type in types:
                if entity_type not in type2count:
                    type2count[entity_type] = 0
                type2count[entity_type] += 1
        # (2) only keep the 7 most frequently occurring types of each entity
        types_used = set()
        entity2types = {}
        for mid in filtered_mids:
            all_types = mid2types[mid]
            all_types_with_count = []
            for entity_type in all_types:
                all_types_with_count.append((entity_type, type2count[entity_type]))
            all_types_with_count.sort(key=lambda x: x[1], reverse=True)
            most_occurring_types = [entity_type for entity_type, _ in all_types_with_count][:7]
            # we arrange the types from least occurring to most occurring
            most_occurring_types.reverse()
            entity2types[mid] = most_occurring_types
            types_used.update(most_occurring_types)
        print("Display entity type statistics after filtering")
        total_length = []
        for node in entity2types:
            total_length.append(len(entity2types[node]))
        plt.hist(np.array(total_length), bins=14)
        plt.show()
        print("average number of types is", sum(total_length) / len(total_length))
        print("max number of types for an entity is", max(total_length))
        print("Number of types:", len(types_used))
        print("Writing entity2types to file")
        with open(os.path.join(self.dir, "entity2types.json"), "w+") as fh:
            json.dump(entity2types, fh)

    def get_mid_to_name(self):
        """
        This function retrieves the name of each entity by following its Wikidata link in
        ``/FB15k-237/fb2w.nt`` and saves the names to ``mid2name.pkl``.
        """
        # 1. collect all mids in the data set
        files = glob.glob(os.path.join(self.dir, "FB15k-237/*.txt"))
        for file in files:
            if "train" in file:
                train_filename = file
            elif "test" in file:
                test_filename = file
            elif "valid" in file:
                dev_filename = file
        mids = set()
        for filename in [train_filename, test_filename, dev_filename]:
            with open(filename, "r") as fh:
                for line in fh:
                    line = line.strip()
                    if len(line) == 0:
                        continue
                    source, relation, target = line.split("\t")
                    mids.add(source.replace("/", "|"))
                    mids.add(target.replace("/", "|"))
        mid2name = {}
        count = 0
        definitions_filename = os.path.join(self.dir, "FB15k-237/fb2w.nt")
        # 2. read the MID-to-Wikidata mapping and query Wikidata for entity names
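        # each fb2w.nt line is assumed to look roughly like:
        #   <http://rdf.freebase.com/ns/m.0abc> <...sameAs...> <http://www.wikidata.org/entity/Q42> .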
        with open(definitions_filename, "r") as fh:
            for line in fh:
                line = line.strip()
                if "\t" not in line:
                    continue
                line = line[:-1].strip()
                mid, _, wikidata_url = line.split("\t")
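                # convert "<http://rdf.freebase.com/ns/m.0abc>" to "/m/0abc" and then,
                # as elsewhere, replace "/" with "|"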
mid = "/" + mid[1:-1].split("/")[-1].replace(".", "/")
mid = mid.replace("/", "|")
# only get names of mids that are in the data
if mid not in mids:
continue
wikid = wikidata_url[1:-1].split("/")[-1]
try:
entity_name = str(self.client.get(wikid).label)
entity_name = "_".join(entity_name.lower().split(" "))
mid2name[mid] = entity_name
print("{}/{}: {} <---> {}".format(count, len(mids), mid, entity_name))
count += 1
except:
print("Cannot retrive name for {}".format(mid))
mid_to_name_filename = os.path.join(self.dir, "mid2name.pkl")
pickle.dump(mid2name, open(mid_to_name_filename, "wb"))

    def write_relation_domain_and_ranges(self):
        """
        This function writes the relations' domains and ranges to ``domains.tsv`` and ``ranges.tsv``.
        """
        domains_filename = os.path.join(self.dir, "domains.tsv")
        ranges_filename = os.path.join(self.dir, "ranges.tsv")
        with open(domains_filename, "w+") as fd, open(ranges_filename, "w+") as fr:
            for rel in self.relations:
                fd.write(rel + "\t" + self.relation_domain[rel] + "\n")
                fr.write(rel + "\t" + self.relation_range[rel] + "\n")

    def write_edges(self):
        """
        This function writes all processed relation instances to ``edges.txt``.
        """
        edges_filename = os.path.join(self.dir, "edges.txt")
        with open(edges_filename, "w+") as fe:
            for subj, rel, obj in self.train_instances:
                fe.write(subj + "\t" + rel + "\t" + obj + "\n")
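

if __name__ == "__main__":
    # Example driver (a sketch; the data directory and the embedding file below are
    # assumptions, not paths shipped with this code). The embeddings are expected to
    # be keyed by raw Freebase MIDs such as "/m/0abc", e.g. Google's
    # freebase-vectors-skipgram1000 word2vec release in binary format.
    reader = MIDFreebase15kReader("data/fb15k237", filter=True,
                                  word2vec_filename="data/freebase-vectors-skipgram1000.bin")
    reader.read_data()
    reader.write_relation_domain_and_ranges()
    reader.write_edges()
    reader.get_mid_to_name()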