@@ -22,38 +22,85 @@ def load_model(model_path, mode):
22
22
23
23
24
24
def _ner_predict (
25
- text ,
26
- model ,
27
- tokenizer ,
28
- max_len ,
29
- device
25
+ text ,
26
+ model ,
27
+ tokenizer ,
28
+ max_len ,
29
+ device ,
30
+ exceed_strategy
30
31
):
31
32
model .to (device )
32
33
33
- encoding = tokenizer (
34
- text ,
35
- add_special_tokens = False ,
36
- max_length = max_len ,
37
- padding = "max_length" ,
38
- truncation = True ,
39
- return_offsets_mapping = True ,
40
- return_tensors = "pt" ,
41
- )
42
-
43
- input_ids = encoding ["input_ids" ].to (device )
44
- attention_mask = encoding ["attention_mask" ].to (device )
45
- token_type_ids = torch .zeros_like (input_ids ).to (device )
46
- offset_mapping = encoding ["offset_mapping" ].squeeze ().tolist ()
47
- tokens = tokenizer .convert_ids_to_tokens (input_ids .squeeze ())
34
+ pred_label_ids = []
35
+
36
+ if exceed_strategy == "truncation" :
37
+ encoding = tokenizer (
38
+ text ,
39
+ add_special_tokens = False ,
40
+ max_length = max_len ,
41
+ padding = "max_length" ,
42
+ truncation = True ,
43
+ return_offsets_mapping = True ,
44
+ return_tensors = "pt" ,
45
+ )
48
46
49
- with torch .no_grad ():
50
- outputs = model (
51
- input_ids = input_ids ,
52
- attention_mask = attention_mask ,
53
- token_type_ids = token_type_ids ,
47
+ input_ids = encoding ["input_ids" ].to (device )
48
+ attention_mask = encoding ["attention_mask" ].to (device )
49
+ token_type_ids = torch .zeros_like (input_ids ).to (device )
50
+ offset_mapping = encoding ["offset_mapping" ].squeeze ().tolist ()
51
+ tokens = tokenizer .convert_ids_to_tokens (input_ids .squeeze ())
52
+
53
+ with torch .no_grad ():
54
+ outputs = model (
55
+ input_ids = input_ids ,
56
+ attention_mask = attention_mask ,
57
+ token_type_ids = token_type_ids ,
58
+ )
59
+ pred_label_ids = outputs ["pred_label_ids" ].cpu ().numpy ()[0 ]
60
+
61
+ else : # exceed_strategy == "sliding_window":
62
+ full_encoding = tokenizer (
63
+ text ,
64
+ add_special_tokens = False ,
65
+ return_offsets_mapping = True ,
66
+ return_tensors = "pt"
54
67
)
55
68
56
- pred_label_ids = outputs ["pred_label_ids" ].cpu ().numpy ()[0 ]
69
+ input_ids = full_encoding ["input_ids" ].to (device )
70
+ tokens = full_encoding .tokens ()
71
+ attention_mask = full_encoding ["attention_mask" ].to (device )
72
+ token_type_ids = torch .zeros_like (input_ids ).to (device )
73
+ offset_mapping = full_encoding ["offset_mapping" ].squeeze ().tolist ()
74
+
75
+ if len (tokens ) <= max_len :
76
+ with torch .no_grad ():
77
+ outputs = model (
78
+ input_ids = input_ids ,
79
+ attention_mask = attention_mask ,
80
+ token_type_ids = token_type_ids ,
81
+ )
82
+ pred_label_ids = outputs ["pred_label_ids" ].cpu ().numpy ()[0 ]
83
+
84
+ else :
85
+ window_size = max_len
86
+ stride = window_size // 2
87
+
88
+ start_token_idx = 0
89
+ while True :
90
+ end_token_idx = min (start_token_idx + window_size , len (tokens ))
91
+ with torch .no_grad ():
92
+ window_pred_label_ids = model (
93
+ input_ids = input_ids [start_token_idx :end_token_idx ],
94
+ attention_mask = attention_mask [start_token_idx :end_token_idx ],
95
+ token_type_ids = token_type_ids [start_token_idx :end_token_idx ],
96
+ )["pred_label_ids" ].cpu ().numpy ()[0 ]
97
+ if end_token_idx >= len (tokens ):
98
+ pred_label_ids .extend (window_pred_label_ids ) # 最后一个窗口全部保留
99
+ break
100
+ else :
101
+ pred_label_ids .extend (window_pred_label_ids [0 : stride ]) # 只保留每个窗口的前 stride 部分
102
+
103
+ start_token_idx += stride
57
104
58
105
entities = []
59
106
char_labels = ["O" ] * len (text )
@@ -64,10 +111,10 @@ def _ner_predict(
64
111
offset = offset_mapping [i ]
65
112
if label .startswith ("B-" ):
66
113
char_labels [offset [0 ]] = label
67
- char_labels [offset [0 ]+ 1 : offset [1 ]] = ["I-" + label [2 :]] * (offset [1 ] - offset [0 ] - 1 )
114
+ char_labels [offset [0 ] + 1 : offset [1 ]] = ["I-" + label [2 :]] * (offset [1 ] - offset [0 ] - 1 )
68
115
elif label .startswith ("I-" ):
69
116
char_labels [offset [0 ]: offset [1 ]] = [label ] * (offset [1 ] - offset [0 ])
70
-
117
+
71
118
# 从 char_labels 中推断实体
72
119
i = 0
73
120
while i < len (char_labels ):
@@ -76,31 +123,31 @@ def _ner_predict(
76
123
start = i
77
124
i += 1
78
125
while i < len (char_labels ) and (
79
- char_labels [i ] == f"I-{ entity_type } "
80
- or (char_labels [i ] == f"O" and tokens [i ].startswith ("##" ))
126
+ char_labels [i ] == f"I-{ entity_type } "
127
+ or (char_labels [i ] == f"O" and tokens [i ].startswith ("##" ))
81
128
):
82
129
i += 1
83
130
end = i
84
- entities .append ({"start" : start ,
85
- "end" : end ,
86
- "type" : entity_type ,
87
- "text" : text [start :end ]}) # 缺少 id
131
+ entities .append ({"start" : start ,
132
+ "end" : end ,
133
+ "type" : entity_type ,
134
+ "text" : text [start :end ]}) # 缺少 id
88
135
else :
89
136
i += 1
90
137
return entities
91
138
92
139
93
140
def _re_predict (
94
- text ,
95
- e1 ,
96
- e2 ,
97
- model ,
98
- tokenizer ,
99
- max_len ,
100
- device
141
+ text ,
142
+ e1 ,
143
+ e2 ,
144
+ model ,
145
+ tokenizer ,
146
+ max_len ,
147
+ device
101
148
):
102
149
model .to (device )
103
-
150
+
104
151
encoding = tokenizer (
105
152
text ,
106
153
add_special_tokens = False ,
@@ -114,10 +161,10 @@ def _re_predict(
114
161
attention_mask = encoding ["attention_mask" ].to (device )
115
162
token_type_ids = torch .zeros_like (input_ids ).to (device )
116
163
offset_mapping = encoding ["offset_mapping" ].squeeze ().tolist ()
117
-
164
+
118
165
e1_mask = _create_entity_mask (input_ids , offset_mapping , e1 [0 ], e1 [1 ])
119
166
e2_mask = _create_entity_mask (input_ids , offset_mapping , e2 [0 ], e2 [1 ])
120
-
167
+
121
168
with torch .no_grad ():
122
169
outputs = model (
123
170
input_ids = input_ids ,
@@ -126,24 +173,25 @@ def _re_predict(
126
173
e1_mask = e1_mask ,
127
174
e2_mask = e2_mask
128
175
)
129
-
176
+
130
177
logits = outputs ["logits" ]
131
178
probs = torch .nn .functional .softmax (logits , dim = 1 ).cpu ().numpy ()[0 ]
132
-
179
+
133
180
pred_idx = logits .argmax (dim = 1 ).item ()
134
181
relation = id2relation [pred_idx ]
135
182
probability = probs [pred_idx ]
136
- return {"source" : text [e1 [0 ]:e1 [1 ]],
137
- "target" : text [e2 [0 ]:e2 [1 ]],
138
- "type" : relation ,
183
+ return {"source" : text [e1 [0 ]:e1 [1 ]],
184
+ "target" : text [e2 [0 ]:e2 [1 ]],
185
+ "type" : relation ,
139
186
"probability" : float (probability )} # 缺少 source_id 和 target_id, 多 probability
140
187
141
188
142
189
def ner_predict(
    text: str,
    model_path: str = "experimental/scripts/ke/checkpoints/ner/final_model",
    max_len: int = 512,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    exceed_strategy: str = "truncation",
) -> list[dict]:
    """Run named-entity recognition over ``text`` with a fine-tuned model.

    Args:
        text: Raw input string to tag.
        model_path: Directory containing the fine-tuned NER model and tokenizer.
        max_len: Maximum number of tokens per model forward pass.
        device: Torch device string; defaults to CUDA when available.
        exceed_strategy: How to handle texts longer than ``max_len`` tokens.
            ``"truncation"`` drops the tail; ``"sliding_window"`` predicts over
            overlapping windows and stitches the per-window results together.

    Returns:
        list[dict]: Predicted entities, each with ``start``, ``end``, ``type``
        and ``text`` keys.

    Raises:
        ValueError: If ``exceed_strategy`` is not a recognised value.
    """
    # Fail fast on typos: _ner_predict silently treats ANY value other than
    # "truncation" as sliding-window, which would mask configuration mistakes.
    if exceed_strategy not in ("truncation", "sliding_window"):
        raise ValueError(
            "exceed_strategy must be 'truncation' or 'sliding_window', "
            f"got {exceed_strategy!r}"
        )
    model, tokenizer = load_model(model_path, "ner")
    return _ner_predict(text, model, tokenizer, max_len, device, exceed_strategy)
210
+
163
211
164
212
def re_predict(
    text: str,
    e1_range: tuple[int, int],
    e2_range: tuple[int, int],
    model_path: str = "experimental/scripts/ke/checkpoints/re/final_model",
    max_len: int = 512,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    exceed_strategy: str = "truncation",
) -> dict:
    """Predict the relation between two entity spans in ``text``.

    Args:
        text: Raw input string containing both entities.
        e1_range: ``(start, end)`` character offsets of the source entity.
        e2_range: ``(start, end)`` character offsets of the target entity.
        model_path: Directory containing the fine-tuned RE model and tokenizer.
        max_len: Maximum number of tokens per model forward pass.
        device: Torch device string; defaults to CUDA when available.
        exceed_strategy: Accepted for API symmetry with ``ner_predict``.
            Only ``"truncation"`` is currently implemented for RE (see note
            below); other recognised values are accepted but behave the same.

    Returns:
        dict: Prediction with ``source``, ``target``, ``type`` and
        ``probability`` keys.

    Raises:
        ValueError: If ``exceed_strategy`` is not a recognised value.
    """
    if exceed_strategy not in ("truncation", "sliding_window"):
        raise ValueError(
            "exceed_strategy must be 'truncation' or 'sliding_window', "
            f"got {exceed_strategy!r}"
        )
    model, tokenizer = load_model(model_path, "re")
    # BUG FIX: _re_predict takes no exceed_strategy parameter, so forwarding
    # it (as this function previously did) raised TypeError on every call.
    # Keep the keyword in this public signature for symmetry with ner_predict,
    # but do not forward it until _re_predict grows sliding-window support —
    # RE inputs are always truncated for now.
    return _re_predict(text, e1_range, e2_range, model, tokenizer, max_len, device)
0 commit comments