Create run_detect.py
This commit is contained in:
parent
9e3b1cda22
commit
b66ec51c46
@ -0,0 +1,92 @@
|
||||
import argparse
import json
import os
import pickle  # used throughout for Data/<dataset>/*.pkl caches; previously only available via the star imports

from MICNN import *
from Hi_Attn import *
|
||||
def _load_data(params, args):
    """Load the train/dev/test data for ``args.dataset``.

    Returns the cached pickle at ``Data/<dataset>/data.pkl`` when it exists;
    otherwise regenerates the bags via ``GenerateTrain`` (brought in by the
    star imports at the top of the file).

    Args:
        params: dict of model parameters; only ``params["batch_size"]`` is
            read, and only on the cache-miss path.
        args: parsed CLI namespace; ``args.dataset`` selects the data folder.

    Returns:
        The unpickled / generated data. The caller unpacks it as
        ``(train_batches, dev_batches, test_batches, vocabs, embedding)``.
    """
    cache_path = os.path.join("Data", args.dataset, "data.pkl")
    if os.path.isfile(cache_path):
        print("Loading data.pkl for analyzing", args.dataset)
        # Context manager closes the handle promptly; the original left the
        # open file object to the garbage collector.
        with open(cache_path, "rb") as fh:
            data = pickle.load(fh)
    else:
        print("data.pkl not found. Generating bags ..")
        data = GenerateTrain(params["batch_size"], args.dataset)
    print("Loading train and test sets")
    return data
||||
|
||||
def _load_unlabeled(params, args, vocabs):
    """Return batches of unlabeled patch data, or ``[]`` when training.

    Uses the cached pickle at ``Data/<dataset>/patch.pkl`` when present;
    otherwise batches the raw patch data via ``GenerateUnlabeled`` (from the
    star imports at the top of the file).

    Args:
        params: dict of model parameters; only ``params["batch_size"]`` is
            read, and only on the cache-miss path.
        args: parsed CLI namespace; ``args.goal`` and ``args.dataset`` are read.
        vocabs: vocabulary object forwarded to ``GenerateUnlabeled``.

    Returns:
        A list of unlabeled batches — empty when ``args.goal == "train"``.
    """
    # Guard clause: training never touches unlabeled data.
    if args.goal == "train":
        return []

    patch_path = os.path.join("Data", args.dataset, "patch.pkl")
    if os.path.isfile(patch_path):
        print("Loading unlabeled patch batches")
        # Close the handle deterministically instead of leaking it.
        with open(patch_path, "rb") as fh:
            return pickle.load(fh)

    print("Batching unlabeled patch data")
    return GenerateUnlabeled(vocabs, params["batch_size"], args.dataset)
||||
|
||||
def _detect(params, data):
    """Build the selected model and run the requested goal.

    NOTE(review): this function reads the module-level ``args`` (parsed in the
    ``__main__`` block) rather than receiving it as a parameter — it can only
    be called after argument parsing. Signature kept unchanged for callers.

    Args:
        params: dict of model parameters (``params["batch_size"]`` is read).
        data: tuple of (train_batches, dev_batches, test_batches, vocabs,
            embedding), as produced by ``_load_data``.

    Raises:
        ValueError: if ``args.model`` is neither "MICNN" nor "ATTN".
    """
    train_batches, dev_batches, test_batches, vocabs, embedding = data
    unlabeled_batches = _load_unlabeled(params, args, vocabs)

    if args.model == "MICNN":
        model = MICNN(params, vocabs, embedding)
    elif args.model == "ATTN":
        model = Hi_Attn(params, vocabs, embedding)
    else:
        # The original fell through here and crashed later with an
        # UnboundLocalError on ``model``; fail fast with a clear message.
        raise ValueError(
            "Unknown --model %r; expected 'MICNN' or 'ATTN'" % (args.model,))

    model.build()

    if args.goal == "train":
        train_sent, dev_sent, test_sent = model.run_model(
            BatchIt(train_batches, params["batch_size"], vocabs),
            BatchIt(dev_batches, params["batch_size"], vocabs),
            BatchIt(test_batches, params["batch_size"], vocabs))
        # Attach best-sentence indices back onto the articles; the list order
        # (train, test, dev) deliberately matches the enumeration below.
        best_sent = [train_sent, test_sent, dev_sent]
        if train_sent:
            for j, articles in enumerate([train_batches, test_batches, dev_batches]):
                for i in range(len(articles)):
                    articles[i]["best_sent"] = best_sent[j][i]

        # Re-cache the (now annotated) data; ``with`` closes the handle
        # deterministically instead of leaking it.
        with open("Data/" + args.dataset + "/data.pkl", "wb") as fh:
            pickle.dump((train_batches, dev_batches, test_batches, vocabs, embedding), fh)

    elif args.goal == "predict":
        hate_pred, indice_pred = model.predict(unlabeled_batches)
        with open("Data/" + args.dataset + "/labels.pkl", "wb") as fh:
            pickle.dump(hate_pred, fh)
        # Keep only batches predicted positive; MICNN additionally carries
        # per-batch sentence indices, ATTN does not.
        hate_batches = list()
        for i in range(len(unlabeled_batches)):
            if hate_pred[i] == 1:
                if args.model == "MICNN":
                    hate_batches.append(unlabeled_batches[i] + (hate_pred[i], indice_pred[i], i))
                else:
                    hate_batches.append(unlabeled_batches[i] + (hate_pred[i], i))
        with open("Data/" + args.dataset + "/predict.pkl", "wb") as fh:
            pickle.dump(hate_batches, fh)

    elif args.goal == "active":
        # Active-learning pass: dump raw probabilities for later ranking.
        probabilities = model.predict(unlabeled_batches, active_learning=True)
        with open("probability.pkl", "wb") as fh:
            pickle.dump(probabilities, fh)
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--model", help="Model is either MICNN or ATTN")
    parser.add_argument("--goal", help="Goal can be either train or predict")
    parser.add_argument("--dataset", help="Dataset is either hate, homicide or kidnap")
    parser.add_argument("--params", help="Path to the params file, a json file "
                                         "that contains model parameters")

    args = parser.parse_args()

    # Best-effort load of the user-supplied params file; fall back to the
    # bundled defaults on any failure (missing --params, bad path, bad JSON).
    try:
        with open(args.params, "r") as fh:
            params = json.load(fh)
    except Exception:
        # BUG FIX: the two implicitly-concatenated literals previously read
        # "...loading the defaultparameters instead" (missing space).
        print("Error in reading from the provided path, loading the default "
              "parameters instead")
        with open("params.json", "r") as fh:
            params = json.load(fh)

    params["dataset"] = args.dataset
    data = _load_data(params, args)
    _detect(params, data)
Loading…
Reference in New Issue
Block a user