Skip to content

Commit ff68003

Browse files
committed
Initial code release
1 parent 252bf89 commit ff68003

File tree

9 files changed

+3025
-0
lines changed

9 files changed

+3025
-0
lines changed

encoding.py

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
from pathlib import Path
2+
from typing import Dict, List, Tuple
3+
4+
from fire import Fire
5+
from pydantic import BaseModel
6+
from tqdm import tqdm
7+
from transformers import AutoTokenizer
8+
9+
from transformer_base import run_summarization
10+
from utils import RelationData, RelationSentence
11+
12+
13+
class Encoder(BaseModel):
    """Abstract interface for mapping RelationSentence objects to and from the
    (source, target) text pairs consumed by a seq2seq model."""

    def encode_x(self, x: str) -> str:
        raise NotImplementedError

    def decode_x(self, x: str) -> str:
        raise NotImplementedError

    def encode(self, sent: RelationSentence) -> Tuple[str, str]:
        raise NotImplementedError

    def decode(self, x: str, y: str) -> RelationSentence:
        raise NotImplementedError

    def encode_to_line(self, sent: RelationSentence) -> str:
        raise NotImplementedError

    def decode_from_line(self, line: str) -> RelationSentence:
        raise NotImplementedError

    def parse_line(self, line: str) -> Tuple[str, str]:
        raise NotImplementedError

    def safe_decode(self, x: str, y: str) -> RelationSentence:
        """Decode (x, y); on any failure, return a placeholder sentence that
        carries the error message and the raw model output for inspection."""
        text = self.decode_x(x)
        try:
            result = self.decode(x=x, y=y)
        except Exception as e:
            # Empty triplet: callers can check .error / .raw to diagnose.
            result = RelationSentence(
                tokens=text.split(), head=[], tail=[], label="", error=str(e), raw=y
            )
        return result
45+
46+
class GenerateEncoder(Encoder):
    """Encoder for the sentence generator: the source is a relation-label
    prompt, the target is that prompt followed by a full triplet sentence."""

    def encode_x(self, r: str) -> str:
        return f"Relation : {r} ."

    def decode_x(self, text: str) -> str:
        # Keep everything after the prompt prefix, then drop the trailing " .".
        return text.split("Relation : ")[-1][:-2]

    def encode_triplet(self, sent: RelationSentence) -> str:
        head, _, tail = sent.as_tuple()
        return f"Context : {sent.text} Head Entity : {head} , Tail Entity : {tail} ."

    def decode_triplet(self, text: str, label: str) -> RelationSentence:
        prefix, remainder = text.split(" Head Entity : ")
        _, context = prefix.split("Context : ")
        head, remainder = remainder.split(" , Tail Entity : ")
        tail = remainder[:-2]  # drop the trailing " ."
        return RelationSentence.from_spans(context, head, tail, label)

    def encode_y(self, sent: RelationSentence) -> str:
        return " ".join([self.encode_x(sent.label), self.encode_triplet(sent)])

    def decode_y(self, text: str, label: str) -> RelationSentence:
        del label  # the label is recovered from the generated text itself
        front, back = text.split(" . Context : ")
        return self.decode_triplet("Context : " + back, self.decode_x(front + " ."))

    def decode(self, x: str, y: str) -> RelationSentence:
        return self.decode_y(y, self.decode_x(x))

    def encode(self, sent: RelationSentence) -> Tuple[str, str]:
        return self.encode_x(sent.label), self.encode_y(sent)

    def decode_from_line(self, line: str) -> RelationSentence:
        x, y = self.parse_line(line)
        return self.decode(x, y)

    def encode_to_line(self, sent: RelationSentence) -> str:
        # Only the target side is written; the source is implied by the label.
        _, y = self.encode(sent)
        return y + "\n"

    def parse_line(self, line: str) -> Tuple[str, str]:
        return "", line.strip()
94+
95+
class ExtractEncoder(Encoder):
    """Encoder for the relation extractor: the source is the sentence context,
    the target is the head/tail/relation triplet as text."""

    def encode_x(self, text: str) -> str:
        return f"Context : {text}"

    def decode_x(self, x: str) -> str:
        return x.split("Context : ")[-1]

    def encode_y(self, sent: RelationSentence) -> str:
        head, label, tail = sent.as_tuple()
        return f"Head Entity : {head} , Tail Entity : {tail} , Relation : {label} ."

    def decode_y(self, x: str, y: str) -> RelationSentence:
        context = self.decode_x(x)
        remainder, label = y.split(" , Relation : ")
        remainder, tail = remainder.split(" , Tail Entity : ")
        _, head = remainder.split("Head Entity : ")
        # label[:-2] drops the trailing " ." appended by encode_y.
        return RelationSentence.from_spans(context, head, tail, label[:-2])

    def encode_entity_prompt(self, head: str, tail: str) -> str:
        # A strict prefix of encode_y output, for constrained decoding.
        return f"Head Entity : {head} , Tail Entity : {tail} , Relation :"

    def encode(self, sent: RelationSentence) -> Tuple[str, str]:
        return self.encode_x(sent.text), self.encode_y(sent)

    def decode(self, x: str, y: str) -> RelationSentence:
        return self.decode_y(x, y)

    def encode_to_line(self, sent: RelationSentence) -> str:
        x, y = self.encode(sent)
        return run_summarization.encode_to_line(x, y)

    def decode_from_line(self, line: str) -> RelationSentence:
        x, y = self.parse_line(line)
        return self.decode(x, y)

    def parse_line(self, line: str) -> Tuple[str, str]:
        return run_summarization.decode_from_line(line)
136+
137+
def test_encoders(
    paths: Tuple[str, ...] = (
        "outputs/data/zsl/wiki/unseen_5_seed_0/train.jsonl",
        "outputs/data/zsl/fewrel/unseen_5_seed_0/train.jsonl",
    ),
    print_limit: int = 4,
    encoder_names: Tuple[str, ...] = ("generate", "extract"),
    limit: int = 1000,
):
    """Round-trip check for each encoder: encode every sentence to a line,
    parse and decode it back, and report the fraction recovered exactly.

    Args:
        paths: jsonl relation datasets to check. Defaults are immutable tuples
            rather than lists to avoid the shared mutable-default pitfall.
        print_limit: maximum number of failing examples to print per encoder.
        encoder_names: names accepted by select_encoder.
        limit: maximum number of sentences sampled from each dataset.
    """
    encoders = {k: select_encoder(k) for k in encoder_names}

    for p in paths:
        data = RelationData.load(Path(p))
        _, data = data.train_test_split(min(limit, len(data.sents)), random_seed=0)

        for name, e in tqdm(list(encoders.items())):
            num_fail = 0
            print(dict(name=name, p=p))
            for s in data.sents:
                encoded = e.encode_to_line(s)
                x, y = e.parse_line(encoded)
                decoded: RelationSentence = e.safe_decode(x, y)

                if decoded.as_tuple() != s.as_tuple():
                    if num_fail < print_limit:
                        print(dict(gold=s.as_tuple(), text=s.text))
                        print(dict(pred=decoded.as_tuple(), text=decoded.text))
                        print(dict(x=x, y=y, e=decoded.error))
                        print()
                    num_fail += 1

            print(dict(success_rate=1 - (num_fail / len(data.sents))))
            print("#" * 80)
171+
172+
def select_encoder(name: str) -> Encoder:
    """Return the encoder registered under *name* ("extract" or "generate").

    Raises KeyError for an unknown name.
    """
    mapping: Dict[str, Encoder] = {
        "extract": ExtractEncoder(),
        "generate": GenerateEncoder(),
    }
    return mapping[name]
180+
181+
def test_entity_prompts(
    path: str = "outputs/data/zsl/wiki/unseen_10_seed_0/test.jsonl", limit: int = 100
):
    """Verify that the entity prompt tokenizes to an exact prefix of the
    tokenized extraction target, so it can seed constrained decoding."""

    def tokenize(text: str, tok) -> List[str]:
        return tok.convert_ids_to_tokens(tok(text, add_special_tokens=False).input_ids)

    data = RelationData.load(Path(path))
    encoder = ExtractEncoder()
    tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
    print(tokenizer)
    for i, sent in enumerate(tqdm(data.sents[:limit])):
        head, _, tail = sent.as_tuple()
        _, y = encoder.encode(sent)
        prompt = encoder.encode_entity_prompt(head, tail)
        target_tokens = tokenize(y, tokenizer)
        prompt_tokens = tokenize(prompt, tokenizer)
        assert target_tokens[: len(prompt_tokens)] == prompt_tokens
        if i < 3:
            print(target_tokens)
201+
202+
# Expose this module's functions (test_encoders, test_entity_prompts, ...)
# as a command-line interface via python-fire.
if __name__ == "__main__":
    Fire()

generation.py

Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
from typing import Dict, List, Optional, Tuple
2+
3+
import torch
4+
from fire import Fire
5+
from torch import Tensor
6+
from transformers import PreTrainedModel, PreTrainedTokenizerFast
7+
8+
from encoding import ExtractEncoder
9+
from utils import DynamicModel, RelationSentence, find_sublist_index
10+
11+
12+
class TextGenerator(DynamicModel):
    """Wraps a seq2seq model and its tokenizer for batched generation, with
    optional decoder-side prompts and per-step score capture."""

    model: PreTrainedModel
    tokenizer: PreTrainedTokenizerFast
    # Per-sequence logits from the most recent run(save_scores=True); else None.
    scores: Optional[List[Tensor]] = None
    max_length: int

    def tokenize(self, texts: List[str], **kwargs):
        """Tokenize a batch (padded/truncated to max_length) on the model device."""
        return self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
            **kwargs,
        ).to(self.model.device)

    def run(
        self,
        texts: List[str],
        do_sample=True,
        top_k=50,
        temperature=1.0,
        num_return: int = 4,
        prompt: Optional[str] = None,
        prompt_ids: Optional[List[int]] = None,
        multi_prompt_ids: Optional[List[List[int]]] = None,
        decoder_input_ids: Optional[Tensor] = None,
        save_scores: bool = False,
        **kwargs,
    ) -> List[str]:
        """Generate num_return continuations per input text.

        The prompt can be given as text (`prompt`), token ids (`prompt_ids`),
        one id list per input (`multi_prompt_ids`), or a ready-made tensor
        (`decoder_input_ids`); each later form overrides the earlier ones.
        With save_scores=True, per-step logits are kept in self.scores.
        """
        # https://huggingface.co/transformers/v4.7.0/main_classes/model.html#generation
        tok = self.tokenizer
        eos, bos = tok.eos_token_id, tok.bos_token_id

        # Normalize every prompt form into decoder_input_ids. The [eos, bos]
        # prefix presumably matches BART's decoder start convention — confirm
        # if another architecture is used.
        if prompt is not None:
            prompt_ids = self.tokenizer(prompt, add_special_tokens=False).input_ids
        if prompt_ids is not None:
            prompt_ids = [eos, bos] + prompt_ids
            decoder_input_ids = torch.tensor([prompt_ids])
        if multi_prompt_ids is not None:
            assert len(texts) == len(multi_prompt_ids)
            multi_prompt_ids = [[eos, bos] + lst for lst in multi_prompt_ids]
            decoder_input_ids = torch.tensor(multi_prompt_ids)
        if decoder_input_ids is not None:
            kwargs.update(decoder_input_ids=decoder_input_ids.to(self.model.device))

        outputs = self.model.generate(
            **self.tokenize(texts),
            do_sample=do_sample,
            top_k=top_k,
            temperature=temperature,
            num_return_sequences=num_return,
            return_dict_in_generate=True,
            output_scores=save_scores,
            max_length=self.max_length,
            **kwargs,
        )

        self.scores = None
        if save_scores:
            # One (steps, vocab) tensor per generated sequence, moved to CPU.
            self.scores = list(torch.stack(outputs.scores, 1).cpu())
        return self.decode(outputs.sequences)

    def decode(self, outputs) -> List[str]:
        """Decode generated ids to text, removing bos/eos/pad markers manually
        (skip_special_tokens could also drop custom special tokens we want)."""
        tok = self.tokenizer
        texts = tok.batch_decode(
            outputs, skip_special_tokens=False, clean_up_tokenization_spaces=False
        )

        # Manually remove <bos><eos><pad> in case we have custom special tokens.
        # Fixed: skip tokens that are None (e.g. a tokenizer with no pad token);
        # str.replace(None, "") would raise TypeError.
        special_tokens = [
            t for t in (tok.eos_token, tok.bos_token, tok.pad_token) if t is not None
        ]
        for i, t in enumerate(texts):
            for token in special_tokens:
                t = t.replace(token, "")
            texts[i] = t
        return texts
89+
90+
class LabelConstraint:
    """Snaps a decoded triplet's relation onto the closest known label, judged
    by the logits at the position right after the " Relation :" prefix."""

    def __init__(
        self,
        labels: List[str],
        tokenizer: PreTrainedTokenizerFast,
        prefix: str = " Relation :",
    ):
        self.prefix: List[int] = tokenizer(prefix, add_special_tokens=False).input_ids
        # First token id of each label (tokenized with a leading space) -> label.
        self.label_map: Dict[int, str] = {
            tokenizer(" " + x, add_special_tokens=False).input_ids[0]: x for x in labels
        }
        self.tokenizer = tokenizer

    def run(self, triplet: RelationSentence, scores: Tensor) -> RelationSentence:
        """Return a deep copy of *triplet* whose label and score come from the
        best-matching known label; unchanged if the prefix was not generated."""
        triplet = triplet.copy(deep=True)
        assert scores.ndim == 2
        token_ids = scores.argmax(dim=-1).int().tolist()
        i = find_sublist_index(token_ids, self.prefix)
        if i == -1:
            # The greedy decode never produced the relation prefix.
            return triplet

        position = i + len(self.prefix)
        best, best_score = "", -1e9
        for token_id, label in self.label_map.items():
            candidate = scores[position, token_id].item()
            if candidate > best_score:
                best, best_score = label, candidate

        if triplet.label in self.label_map.values():
            # Sanity check: when the decoded label is already a known label,
            # the constrained choice must agree with it.
            assert best == triplet.label

        assert len(best) > 0
        triplet.label = best
        triplet.score = best_score
        return triplet
127+
128+
129+
class TripletSearchDecoder(DynamicModel):
    """Enumerates candidate triplets per sentence by branching the greedy
    decode at the head-entity, tail-entity, and relation positions."""

    gen: TextGenerator
    constraint: LabelConstraint
    encoder: ExtractEncoder
    top_k: int = 4  # branches explored per decision point

    def generate(self, text: str, **kwargs) -> Tuple[str, Tensor]:
        """Greedy-decode one output for *text*; return it with per-step log-probs."""
        outputs = self.gen.run(
            [text],
            do_sample=False,
            num_return=1,
            num_beams=1,
            save_scores=True,
            **kwargs,
        )

        assert len(outputs) == 1
        assert self.gen.scores is not None
        scores = torch.log_softmax(self.gen.scores[0], dim=-1)
        assert scores.ndim == 2
        return outputs[0], scores

    # Fixed annotation: token_ids holds ints, not strs.
    def find_prefix_end(self, token_ids: List[int], prefix: str) -> int:
        """Return the index just past *prefix*'s tokens inside token_ids.

        NOTE(review): if the prefix is absent, find_sublist_index returns -1
        and the result is meaningless — callers rely on the prefix being
        present in the greedy decode; confirm upstream.
        """
        prefix_ids = self.gen.tokenizer(prefix, add_special_tokens=False).input_ids
        i = find_sublist_index(token_ids, prefix_ids)
        position = i + len(prefix_ids)
        return position

    def branch(
        self, text: str, prefix: str, prompt: Optional[str] = None, **kwargs
    ) -> List[Tuple[str, float]]:
        """Return up to top_k (extended prompt, log-prob) pairs for the token
        that follows *prefix* in the greedy decode of *text*."""
        _, scores = self.generate(text, prompt=prompt, **kwargs)
        token_ids = scores.argmax(dim=-1).int().tolist()
        i = self.find_prefix_end(token_ids, prefix)

        pairs = []
        for j in torch.argsort(scores[i])[-self.top_k :]:
            p = (prompt or "") + self.gen.decode([token_ids[:i] + [j]])[0]
            pairs.append((p, scores[i, j].item()))

        return pairs

    def run(self, text: str) -> List[RelationSentence]:
        """Produce up to top_k * top_k candidate triplets for one sentence,
        each scored by the mean log-prob of its head / tail / relation choices."""
        x = self.encoder.encode_x(text)
        outputs = []

        for prompt_a, score_a in self.branch(x, prefix="Head Entity :"):
            for prompt_b, score_b in self.branch(
                x, prefix=" Tail Entity :", prompt=prompt_a
            ):
                output, scores = self.generate(x, prompt=prompt_b)
                # Fixed: original had a duplicated "token_ids = token_ids = ..."
                token_ids = scores.argmax(dim=-1).int().tolist()
                i = self.find_prefix_end(token_ids, prefix=" Relation :")
                score_c = max(scores[i].tolist())
                s = self.encoder.safe_decode(x=x, y=output)
                # LabelConstraint also sets s.score, but it is overwritten below
                # by the averaged branch score.
                s = self.constraint.run(s, scores)
                s.score = (score_a + score_b + score_c) / 3
                outputs.append(s)

        return outputs
191+
192+
# Expose this module's classes/functions as a command-line interface via
# python-fire.
if __name__ == "__main__":
    Fire()

0 commit comments

Comments
 (0)