-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathsrl.py
More file actions
54 lines (44 loc) · 1.61 KB
/
srl.py
File metadata and controls
54 lines (44 loc) · 1.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
#import jsonlines
import time
from allennlp.models.archival import load_archive
from allennlp.predictors import Predictor
import sys
import os
import argparse
import torch
import parser
# Limit PyTorch to a single CPU thread for intra-op parallelism.
# NOTE(review): presumably chosen to avoid CPU oversubscription when several
# processes run this model concurrently — confirm the intent.
torch.set_num_threads(1)
class PretrainedModel:
    """Pair a trained-model archive with the predictor used to serve it.

    A pretrained model is identified by two things: the archive holding the
    trained weights/config, and the name of the predictor that wraps it.

    Attributes:
        archive_file: path or URL of the archive, handed to ``load_archive``.
        predictor_name: predictor name handed to ``Predictor.from_archive``.
    """

    def __init__(self, archive_file: str, predictor_name: str) -> None:
        self.archive_file, self.predictor_name = archive_file, predictor_name

    def predictor(self) -> Predictor:
        """Load the archive and build the configured predictor from it.

        The archive is loaded afresh on every call; nothing is cached here.
        """
        return Predictor.from_archive(load_archive(self.archive_file),
                                      self.predictor_name)
class AllenSRL:
    """Semantic role labeller backed by a pretrained AllenNLP SRL model.

    The predictor is built once, in ``__init__``, and reused by ``predict``.
    """

    # Archive of the pretrained SRL model and the predictor registered for it.
    DEFAULT_ARCHIVE = 'https://s3-us-west-2.amazonaws.com/allennlp/models/srl-model-2018.05.25.tar.gz'
    DEFAULT_PREDICTOR = 'semantic-role-labeling'

    def __init__(self, archive_file=DEFAULT_ARCHIVE,
                 predictor_name=DEFAULT_PREDICTOR):
        """Load the model archive and build its predictor.

        Args:
            archive_file: path or URL of the model archive; defaults to the
                2018-05-25 SRL model used originally.
            predictor_name: name of the predictor to wrap the model with.
        """
        model = PretrainedModel(archive_file, predictor_name)
        self.predictor = model.predictor()
        # NOTE: to run on GPU, move `self.predictor._model` to CUDA here.

    def predict(self, tokens, *, verbose=True):
        """Return the temporal-modifier (ARGM-TMP) phrase of a tokenized sentence.

        Args:
            tokens: the sentence, already split into words; passed unchanged
                to the predictor's ``predict_tokenized``.
            verbose: when true (the default, matching the old behaviour), print
                the raw prediction dict to stdout.

        Returns:
            The words tagged ``B-ARGM-TMP`` / ``I-ARGM-TMP`` in the FIRST verb
            frame, joined by single spaces. If that frame has several such
            spans, only the last one is returned. Returns ``""`` when there is
            no temporal span or when the model finds no verb at all.
        """
        prediction = self.predictor.predict_tokenized(tokens)
        if verbose:
            print(prediction)

        # A sentence without a predicate yields an empty 'verbs' list; the old
        # code raised IndexError here.
        frames = prediction.get('verbs') or []
        if not frames:
            return ""

        tempargs = ""
        for tag, word in zip(frames[0]['tags'], prediction['words']):
            if tag == 'B-ARGM-TMP':
                # Beginning of a span: start over (a later span wins).
                tempargs = word
            elif tag == 'I-ARGM-TMP':
                tempargs = tempargs + " " + word
        # Previously the result was computed and silently discarded.
        return tempargs
if __name__ == "__main__":
    # Demo run: loading the model archive and running inference are expensive,
    # so they happen only when this file is executed as a script, not on import.
    srl = AllenSRL()
    srl.predict("I ate dinner, 02/04/2002".split(" "))