import torch
from torch.utils.data import Dataset
- from .config import *
+ from config import *
from transformers import PreTrainedTokenizerFast


class NERDataset(Dataset):
-     def __init__(self, data, tokenizer, max_len=128):
-         self.data = data
+     def __init__(self, data, tokenizer, max_len=512, exceed_strategy="truncation"):
        self.tokenizer = tokenizer
        self.max_len = max_len
+         self.data = []
+
+         if exceed_strategy == "truncation":
+             self.data = data
+         elif exceed_strategy == "sliding_window":
+             # only fast tokenizers are supported for now
+             for item in data:
+                 text = item['text']
+                 entities = item['entities']
+
+                 full_encoding = self.tokenizer(
+                     text,
+                     add_special_tokens=False,
+                     return_offsets_mapping=True,
+                     return_tensors="pt"
+                 )
+
+                 tokens = full_encoding.tokens()
+                 offset_mapping = full_encoding["offset_mapping"].squeeze().tolist()
+
+                 if len(tokens) <= max_len:
+                     # item['encoding'] = full_encoding
+                     self.data.append(item)
+                     continue
+
+                 window_size = max_len
+                 stride = window_size // 2
+
+                 start_token_idx = 0
+                 while start_token_idx < len(tokens):
+                     end_token_idx = min(start_token_idx + window_size, len(tokens))
+
+                     # [start_token_idx, end_token_idx) ==> [start_char_idx, end_char_idx)
+                     start_char_idx = offset_mapping[start_token_idx][0]
+                     end_char_idx = offset_mapping[end_token_idx - 1][1]
+
+                     # for each window, keep only entities that lie completely inside it (this may shrink the window)
+                     for entity in entities:
+                         # assumption: entities are much shorter than window_size and stride
+                         if entity['start'] <= start_char_idx < entity['end']:
+                             start_char_idx = entity['end']
+                         if entity['start'] <= end_char_idx < entity['end']:
+                             end_char_idx = entity['start']
+                             break
+                     # start_char_idx and end_char_idx should also be updated accordingly, but this is not handled here
+
+                     window_entities = []
+                     for entity in entities:
+                         if entity['start'] >= start_char_idx and entity['end'] <= end_char_idx:
+                             new_entity = entity.copy()
+                             new_entity['start'] -= start_char_idx
+                             new_entity['end'] -= start_char_idx
+                             window_entities.append(new_entity)
+
+                     window_text = text[start_char_idx:end_char_idx]
+                     window_data = {
+                         'text': window_text,
+                         'entities': window_entities  # the encoding is not stored for now
+                     }
+                     self.data.append(window_data)
+
+                     next_token_idx = start_token_idx + stride  # overlapping windows
+                     start_token_idx = next_token_idx
+         else:
+             pass

    def __len__(self):
        return len(self.data)
@@ -31,21 +95,24 @@ def __getitem__(self, idx):
            for i in range(start + 1, end):
                char_labels[i] = f"I-{entity_type}"

+         # align char_labels with token_labels
        if isinstance(self.tokenizer, PreTrainedTokenizerFast):
-             encoding = self.tokenizer(
-                 text,
-                 add_special_tokens=False,
-                 max_length=self.max_len,
-                 padding="max_length",
-                 truncation=True,
-                 return_offsets_mapping=True,
-                 return_tensors="pt"
-             )
-
+             if self.data[idx].get('encoding'):
+                 encoding = self.data[idx]['encoding']  # may already have been built during preprocessing
+             else:
+                 encoding = self.tokenizer(
+                     text,
+                     add_special_tokens=False,
+                     max_length=self.max_len,
+                     padding="max_length",
+                     truncation=True,
+                     return_offsets_mapping=True,
+                     return_tensors="pt"
+                 )
            input_ids = encoding["input_ids"].squeeze()
            attention_mask = encoding["attention_mask"].squeeze()
-             tokens = self.tokenizer.convert_ids_to_tokens(input_ids)
-             offset_mapping = encoding["offset_mapping"].squeeze().tolist()
+             tokens = encoding.tokens()
+             offset_mapping = encoding["offset_mapping"].squeeze().tolist()  # character span of each token in the original text

            # derive token_labels from the entities
            token_labels = []
@@ -70,7 +137,7 @@ def __getitem__(self, idx):

            input_ids = encoding["input_ids"].squeeze()
            attention_mask = encoding["attention_mask"].squeeze()
-             tokens = self.tokenizer.convert_ids_to_tokens(input_ids)
+             tokens = encoding.tokens()

            token_labels = []
            char_idx = 0
@@ -94,7 +161,7 @@ def __getitem__(self, idx):


class REDataset(Dataset):
-     def __init__(self, data, tokenizer, max_len=128):
+     def __init__(self, data, tokenizer, max_len=512, exceed_strategy="truncation"):
        self.tokenizer = tokenizer
        self.max_len = max_len

@@ -103,13 +170,19 @@ def __init__(self, data, tokenizer, max_len=128):
            text = line['text']
            entities = line['entities']
            relations = line['relations']
-             for relation in relations:
-                 self.data.append({
-                     'text': text,
-                     'e1': next(filter(lambda x: x['id'] == relation['source_id'], entities)),
-                     'e2': next(filter(lambda x: x['id'] == relation['target_id'], entities)),
-                     'relation': relation['type']
-                 })
+
+             if exceed_strategy == "truncation":
+                 for relation in relations:
+                     self.data.append({
+                         'text': text,
+                         'e1': next(filter(lambda x: x['id'] == relation['source_id'], entities)),
+                         'e2': next(filter(lambda x: x['id'] == relation['target_id'], entities)),
+                         'relation': relation['type']
+                     })
+             elif exceed_strategy == "sliding_window":
+                 pass
+             else:
+                 pass

    def __len__(self):
        return len(self.data)
@@ -139,11 +212,11 @@ def __getitem__(self, idx):
            attention_mask = encoding["attention_mask"].squeeze()
            offset_mapping = encoding["offset_mapping"].squeeze().tolist()

-             e1_mask = _create_entity_mask(input_ids, offset_mapping, e1_start, e1_end)
+             e1_mask = _create_entity_mask(input_ids, offset_mapping, e1_start, e1_end)  # entity positions are marked with 1
            e2_mask = _create_entity_mask(input_ids, offset_mapping, e2_start, e2_end)

        else:
-
+
            encoding = self.tokenizer(
                text,
                add_special_tokens=False,
@@ -156,10 +229,10 @@ def __getitem__(self, idx):
            input_ids = encoding["input_ids"].squeeze()
            attention_mask = encoding["attention_mask"].squeeze()
            tokens = self.tokenizer.convert_ids_to_tokens(input_ids)
-
+
            e1_mask = _create_entity_mask2(text, input_ids, tokens, e1_start, e1_end)
            e2_mask = _create_entity_mask2(text, input_ids, tokens, e2_start, e2_end)
-
+
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,