Skip to content

Commit efdfc09

Browse files
committed
feat: 预测阶段滑动窗口
1 parent fd480f7 commit efdfc09

File tree

2 files changed

+114
-65
lines changed

2 files changed

+114
-65
lines changed

experimental/scripts/ke/dataset.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,13 @@ def __init__(self, data, tokenizer, max_len=512, exceed_strategy="truncation"):
3232
)
3333

3434
tokens = full_encoding.tokens()
35-
offset_mapping = full_encoding["offset_mapping"].squeeze().tolist()
3635

3736
if len(tokens) <= max_len:
38-
# item['encoding'] = full_encoding
37+
item['encoding'] = full_encoding
3938
self.data.append(item)
4039
continue
41-
40+
41+
offset_mapping = full_encoding["offset_mapping"].squeeze().tolist()
4242
window_size = max_len
4343
stride = window_size // 2
4444

experimental/scripts/ke/predict.py

Lines changed: 111 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -22,38 +22,85 @@ def load_model(model_path, mode):
2222

2323

2424
def _ner_predict(
25-
text,
26-
model,
27-
tokenizer,
28-
max_len,
29-
device
25+
text,
26+
model,
27+
tokenizer,
28+
max_len,
29+
device,
30+
exceed_strategy
3031
):
3132
model.to(device)
3233

33-
encoding = tokenizer(
34-
text,
35-
add_special_tokens=False,
36-
max_length=max_len,
37-
padding="max_length",
38-
truncation=True,
39-
return_offsets_mapping=True,
40-
return_tensors="pt",
41-
)
42-
43-
input_ids = encoding["input_ids"].to(device)
44-
attention_mask = encoding["attention_mask"].to(device)
45-
token_type_ids = torch.zeros_like(input_ids).to(device)
46-
offset_mapping = encoding["offset_mapping"].squeeze().tolist()
47-
tokens = tokenizer.convert_ids_to_tokens(input_ids.squeeze())
34+
pred_label_ids = []
35+
36+
if exceed_strategy == "truncation":
37+
encoding = tokenizer(
38+
text,
39+
add_special_tokens=False,
40+
max_length=max_len,
41+
padding="max_length",
42+
truncation=True,
43+
return_offsets_mapping=True,
44+
return_tensors="pt",
45+
)
4846

49-
with torch.no_grad():
50-
outputs = model(
51-
input_ids=input_ids,
52-
attention_mask=attention_mask,
53-
token_type_ids=token_type_ids,
47+
input_ids = encoding["input_ids"].to(device)
48+
attention_mask = encoding["attention_mask"].to(device)
49+
token_type_ids = torch.zeros_like(input_ids).to(device)
50+
offset_mapping = encoding["offset_mapping"].squeeze().tolist()
51+
tokens = tokenizer.convert_ids_to_tokens(input_ids.squeeze())
52+
53+
with torch.no_grad():
54+
outputs = model(
55+
input_ids=input_ids,
56+
attention_mask=attention_mask,
57+
token_type_ids=token_type_ids,
58+
)
59+
pred_label_ids = outputs["pred_label_ids"].cpu().numpy()[0]
60+
61+
else: # exceed_strategy == "sliding_window":
62+
full_encoding = tokenizer(
63+
text,
64+
add_special_tokens=False,
65+
return_offsets_mapping=True,
66+
return_tensors="pt"
5467
)
5568

56-
pred_label_ids = outputs["pred_label_ids"].cpu().numpy()[0]
69+
input_ids = full_encoding["input_ids"].to(device)
70+
tokens = full_encoding.tokens()
71+
attention_mask = full_encoding["attention_mask"].to(device)
72+
token_type_ids = torch.zeros_like(input_ids).to(device)
73+
offset_mapping = full_encoding["offset_mapping"].squeeze().tolist()
74+
75+
if len(tokens) <= max_len:
76+
with torch.no_grad():
77+
outputs = model(
78+
input_ids=input_ids,
79+
attention_mask=attention_mask,
80+
token_type_ids=token_type_ids,
81+
)
82+
pred_label_ids = outputs["pred_label_ids"].cpu().numpy()[0]
83+
84+
else:
85+
window_size = max_len
86+
stride = window_size // 2
87+
88+
start_token_idx = 0
89+
while True:
90+
end_token_idx = min(start_token_idx + window_size, len(tokens))
91+
with torch.no_grad():
92+
window_pred_label_ids = model(
93+
input_ids=input_ids[start_token_idx:end_token_idx],
94+
attention_mask=attention_mask[start_token_idx:end_token_idx],
95+
token_type_ids=token_type_ids[start_token_idx:end_token_idx],
96+
)["pred_label_ids"].cpu().numpy()[0]
97+
if end_token_idx >= len(tokens):
98+
pred_label_ids.extend(window_pred_label_ids) # 最后一个窗口全部保留
99+
break
100+
else:
101+
pred_label_ids.extend(window_pred_label_ids[0: stride]) # 只保留每个窗口的前 stride 部分
102+
103+
start_token_idx += stride
57104

58105
entities = []
59106
char_labels = ["O"] * len(text)
@@ -64,10 +111,10 @@ def _ner_predict(
64111
offset = offset_mapping[i]
65112
if label.startswith("B-"):
66113
char_labels[offset[0]] = label
67-
char_labels[offset[0]+1: offset[1]] = ["I-" + label[2:]] * (offset[1] - offset[0] - 1)
114+
char_labels[offset[0] + 1: offset[1]] = ["I-" + label[2:]] * (offset[1] - offset[0] - 1)
68115
elif label.startswith("I-"):
69116
char_labels[offset[0]: offset[1]] = [label] * (offset[1] - offset[0])
70-
117+
71118
# 从 char_labels 中推断实体
72119
i = 0
73120
while i < len(char_labels):
@@ -76,31 +123,31 @@ def _ner_predict(
76123
start = i
77124
i += 1
78125
while i < len(char_labels) and (
79-
char_labels[i] == f"I-{entity_type}"
80-
or (char_labels[i] == f"O" and tokens[i].startswith("##"))
126+
char_labels[i] == f"I-{entity_type}"
127+
or (char_labels[i] == f"O" and tokens[i].startswith("##"))
81128
):
82129
i += 1
83130
end = i
84-
entities.append({"start": start,
85-
"end": end,
86-
"type": entity_type,
87-
"text": text[start:end]}) # 缺少 id
131+
entities.append({"start": start,
132+
"end": end,
133+
"type": entity_type,
134+
"text": text[start:end]}) # 缺少 id
88135
else:
89136
i += 1
90137
return entities
91138

92139

93140
def _re_predict(
94-
text,
95-
e1,
96-
e2,
97-
model,
98-
tokenizer,
99-
max_len,
100-
device
141+
text,
142+
e1,
143+
e2,
144+
model,
145+
tokenizer,
146+
max_len,
147+
device
101148
):
102149
model.to(device)
103-
150+
104151
encoding = tokenizer(
105152
text,
106153
add_special_tokens=False,
@@ -114,10 +161,10 @@ def _re_predict(
114161
attention_mask = encoding["attention_mask"].to(device)
115162
token_type_ids = torch.zeros_like(input_ids).to(device)
116163
offset_mapping = encoding["offset_mapping"].squeeze().tolist()
117-
164+
118165
e1_mask = _create_entity_mask(input_ids, offset_mapping, e1[0], e1[1])
119166
e2_mask = _create_entity_mask(input_ids, offset_mapping, e2[0], e2[1])
120-
167+
121168
with torch.no_grad():
122169
outputs = model(
123170
input_ids=input_ids,
@@ -126,24 +173,25 @@ def _re_predict(
126173
e1_mask=e1_mask,
127174
e2_mask=e2_mask
128175
)
129-
176+
130177
logits = outputs["logits"]
131178
probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy()[0]
132-
179+
133180
pred_idx = logits.argmax(dim=1).item()
134181
relation = id2relation[pred_idx]
135182
probability = probs[pred_idx]
136-
return {"source": text[e1[0]:e1[1]],
137-
"target": text[e2[0]:e2[1]],
138-
"type": relation,
183+
return {"source": text[e1[0]:e1[1]],
184+
"target": text[e2[0]:e2[1]],
185+
"type": relation,
139186
"probability": float(probability)} # 缺少 source_id 和 target_id, 多 probability
140187

141188

142189
def ner_predict(
143-
text: str,
144-
model_path: str = "experimental/scripts/ke/checkpoints/ner/final_model",
145-
max_len: int = 512,
146-
device: str = "cuda" if torch.cuda.is_available() else "cpu",
190+
text: str,
191+
model_path: str = "experimental/scripts/ke/checkpoints/ner/final_model",
192+
max_len: int = 512,
193+
device: str = "cuda" if torch.cuda.is_available() else "cpu",
194+
exceed_strategy: str = "truncation"
147195
) -> list[dict]:
148196
"""
149197
使用模型进行实体预测
@@ -158,16 +206,17 @@ def ner_predict(
158206
list[dict]: 预测结果列表
159207
"""
160208
model, tokenizer = load_model(model_path, "ner")
161-
return _ner_predict(text, model, tokenizer, max_len, device)
162-
209+
return _ner_predict(text, model, tokenizer, max_len, device, exceed_strategy)
210+
163211

164212
def re_predict(
165-
text: str,
166-
e1_range: tuple[int, int],
167-
e2_range: tuple[int, int],
168-
model_path: str = "experimental/scripts/ke/checkpoints/re/final_model",
169-
max_len: int = 512,
170-
device: str = "cuda" if torch.cuda.is_available() else "cpu",
213+
text: str,
214+
e1_range: tuple[int, int],
215+
e2_range: tuple[int, int],
216+
model_path: str = "experimental/scripts/ke/checkpoints/re/final_model",
217+
max_len: int = 512,
218+
device: str = "cuda" if torch.cuda.is_available() else "cpu",
219+
exceed_strategy: str = "truncation"
171220
) -> dict:
172221
"""
173222
使用模型进行关系预测
@@ -184,4 +233,4 @@ def re_predict(
184233
dict: 预测结果
185234
"""
186235
model, tokenizer = load_model(model_path, "re")
187-
return _re_predict(text, e1_range, e2_range, model, tokenizer, max_len, device)
236+
return _re_predict(text, e1_range, e2_range, model, tokenizer, max_len, device, exceed_strategy)

0 commit comments

Comments
 (0)