
Commit b461a83

feat: sliding window to keep text from being truncated

1 parent f4f3e5b commit b461a83

File tree

11 files changed: +891 −2373 lines

.gitignore

Lines changed: 3 additions & 1 deletion
@@ -64,4 +64,6 @@ src/course_graph/database/faiss_index
 src/course_graph/database/milvus.db
 libreoffice_convert.log
 .cursorrules
-.ruff_cache
+.ruff_cache
+
+train.sh

docs/tutorials/other/rust.md

Lines changed: 7 additions & 7 deletions
@@ -11,25 +11,25 @@
 
 ### Writing Rust code
 
-Rust extension code should all go under the `rust/src/ext` directory; see the [PyO3 guide](https://pyo3.rs/v0.15.1/) for implementation details.
+Rust extension code should all go into `src/lib.rs`; see the [PyO3 guide](https://pyo3.rs/v0.15.1/) for implementation details.
 
 ### Exporting functions
 
-Add exported functions in `rust/src/lib.rs`; see the functions that are already exported for how to do this.
+Add exported functions inside the `_core` function in `src/lib.rs`; see the functions that are already exported for how to do this.
 
 ### Writing function stubs
 
 To give the IDE better hints, we can write Python stubs for these functions without writing the actual implementation.
 
-Keep adding function stubs in the `rust/extension.pyi` file; type annotations and docstrings are all that is needed.
+Keep adding function stubs in the `src/course_graph/_core.pyi` file; type annotations and docstrings are all that is needed.
 
 ### Building and installing
 
-Make sure the Rust toolchain, Cargo, and Python's `maturin` are installed, then run:
+Make sure the Rust toolchain and Cargo are installed, then run:
 
 ```bash
-cd rust
-maturin develop
+source .venv/bin/activate
+maturin develop --uv
 ```
 
-All the Rust extension functions you write are installed under the `extension` package.
+All the Rust extension functions you write are installed into the `course_graph._core` package.
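As a rough illustration of the stub step described above, one entry in `src/course_graph/_core.pyi` might look like the sketch below; the `fast_split` function is hypothetical and only stands in for whatever the crate actually exports:

```python
# Illustrative sketch of src/course_graph/_core.pyi.
# `fast_split` is a hypothetical name, not necessarily a function the crate exports.
def fast_split(text: str, sep: str = " ") -> list[str]:
    """Split `text` on `sep` using the Rust implementation."""
    ...
```

The stub carries only the signature and docstring; the body stays `...`, since the real implementation lives in the compiled `course_graph._core` module.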

experimental/README.md

Lines changed: 3 additions & 3 deletions
@@ -14,9 +14,9 @@ experimental/
 │   └── txt/
 │       └── *.txt          # raw plain-text data
 ├── scripts/
-│   ├── ner/               # named entity recognition model
-│   ├── overview.py        # data overview
-│   └── pre_trained/       # pre-trained models
+│   ├── ke/                # knowledge extraction model
+│   ├── overview.py        # data overview
+│   └── pre_trained/       # pre-trained models
 ├── results/               # results
 └── README.md

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+#### Entity recognition model training

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+#### Relation recognition model training

experimental/scripts/ke/dataset.py

Lines changed: 101 additions & 28 deletions
@@ -6,15 +6,79 @@
 
 import torch
 from torch.utils.data import Dataset
-from .config import *
+from config import *
 from transformers import PreTrainedTokenizerFast
 
 
 class NERDataset(Dataset):
-    def __init__(self, data, tokenizer, max_len=128):
-        self.data = data
+    def __init__(self, data, tokenizer, max_len=512, exceed_strategy="truncation"):
         self.tokenizer = tokenizer
         self.max_len = max_len
+        self.data = []
+
+        if exceed_strategy == "truncation":
+            self.data = data
+        elif exceed_strategy == "sliding_window":
+            # only supports fast tokenizers for now
+            for item in data:
+                text = item['text']
+                entities = item['entities']
+
+                full_encoding = self.tokenizer(
+                    text,
+                    add_special_tokens=False,
+                    return_offsets_mapping=True,
+                    return_tensors="pt"
+                )
+
+                tokens = full_encoding.tokens()
+                offset_mapping = full_encoding["offset_mapping"].squeeze().tolist()
+
+                if len(tokens) <= max_len:
+                    # item['encoding'] = full_encoding
+                    self.data.append(item)
+                    continue
+
+                window_size = max_len
+                stride = window_size // 2
+
+                start_token_idx = 0
+                while start_token_idx < len(tokens):
+                    end_token_idx = min(start_token_idx + window_size, len(tokens))
+
+                    # [start_token_idx, end_token_idx) ==> [start_char_idx, end_char_idx)
+                    start_char_idx = offset_mapping[start_token_idx][0]
+                    end_char_idx = offset_mapping[end_token_idx - 1][1]
+
+                    # for each window, keep only the entities that lie entirely inside it (this may shrink the window)
+                    for entity in entities:
+                        # assumption: entity length is far smaller than window_size and stride
+                        if entity['start'] <= start_char_idx < entity['end']:
+                            start_char_idx = entity['end']
+                        if entity['start'] <= end_char_idx < entity['end']:
+                            end_char_idx = entity['start']
+                            break
+                    # start_char_idx and end_char_idx should be adjusted accordingly as well, but this is not handled here
+
+                    window_entities = []
+                    for entity in entities:
+                        if entity['start'] >= start_char_idx and entity['end'] <= end_char_idx:
+                            new_entity = entity.copy()
+                            new_entity['start'] -= start_char_idx
+                            new_entity['end'] -= start_char_idx
+                            window_entities.append(new_entity)
+
+                    window_text = text[start_char_idx:end_char_idx]
+                    window_data = {
+                        'text': window_text,
+                        'entities': window_entities  # no encoding attached for now
+                    }
+                    self.data.append(window_data)
+
+                    next_token_idx = start_token_idx + stride  # overlapping windows
+                    start_token_idx = next_token_idx
+        else:
+            pass
 
     def __len__(self):
         return len(self.data)
@@ -31,21 +95,24 @@ def __getitem__(self, idx):
             for i in range(start + 1, end):
                 char_labels[i] = f"I-{entity_type}"
 
+        # align char_labels to token_labels
         if isinstance(self.tokenizer, PreTrainedTokenizerFast):
-            encoding = self.tokenizer(
-                text,
-                add_special_tokens=False,
-                max_length=self.max_len,
-                padding="max_length",
-                truncation=True,
-                return_offsets_mapping=True,
-                return_tensors="pt"
-            )
-
+            if self.data[idx].get('encoding'):
+                encoding = self.data[idx]['encoding']  # may already have been produced during preprocessing
+            else:
+                encoding = self.tokenizer(
+                    text,
+                    add_special_tokens=False,
+                    max_length=self.max_len,
+                    padding="max_length",
+                    truncation=True,
+                    return_offsets_mapping=True,
+                    return_tensors="pt"
+                )
             input_ids = encoding["input_ids"].squeeze()
             attention_mask = encoding["attention_mask"].squeeze()
-            tokens = self.tokenizer.convert_ids_to_tokens(input_ids)
-            offset_mapping = encoding["offset_mapping"].squeeze().tolist()
+            tokens = encoding.tokens()
+            offset_mapping = encoding["offset_mapping"].squeeze().tolist()  # each token's span in the original text
 
             # derive token_labels from the entities
             token_labels = []
@@ -70,7 +137,7 @@ def __getitem__(self, idx):
 
             input_ids = encoding["input_ids"].squeeze()
             attention_mask = encoding["attention_mask"].squeeze()
-            tokens = self.tokenizer.convert_ids_to_tokens(input_ids)
+            tokens = encoding.tokens()
 
             token_labels = []
             char_idx = 0
@@ -94,7 +161,7 @@ def __getitem__(self, idx):
 
 
 class REDataset(Dataset):
-    def __init__(self, data, tokenizer, max_len=128):
+    def __init__(self, data, tokenizer, max_len=512, exceed_strategy="truncation"):
         self.tokenizer = tokenizer
         self.max_len = max_len
 
@@ -103,13 +170,19 @@ def __init__(self, data, tokenizer, max_len=128):
             text = line['text']
             entities = line['entities']
             relations = line['relations']
-            for relation in relations:
-                self.data.append({
-                    'text': text,
-                    'e1': next(filter(lambda x: x['id'] == relation['source_id'], entities)),
-                    'e2': next(filter(lambda x: x['id'] == relation['target_id'], entities)),
-                    'relation': relation['type']
-                })
+
+            if exceed_strategy == "truncation":
+                for relation in relations:
+                    self.data.append({
+                        'text': text,
+                        'e1': next(filter(lambda x: x['id'] == relation['source_id'], entities)),
+                        'e2': next(filter(lambda x: x['id'] == relation['target_id'], entities)),
+                        'relation': relation['type']
+                    })
+            elif exceed_strategy == "sliding_window":
+                pass
+            else:
+                pass
 
     def __len__(self):
         return len(self.data)
@@ -139,11 +212,11 @@ def __getitem__(self, idx):
             attention_mask = encoding["attention_mask"].squeeze()
             offset_mapping = encoding["offset_mapping"].squeeze().tolist()
 
-            e1_mask = _create_entity_mask(input_ids, offset_mapping, e1_start, e1_end)
+            e1_mask = _create_entity_mask(input_ids, offset_mapping, e1_start, e1_end)  # entity positions are 1 in the mask
             e2_mask = _create_entity_mask(input_ids, offset_mapping, e2_start, e2_end)
 
         else:
-
+
             encoding = self.tokenizer(
                 text,
                 add_special_tokens=False,
@@ -156,10 +229,10 @@ def __getitem__(self, idx):
            input_ids = encoding["input_ids"].squeeze()
            attention_mask = encoding["attention_mask"].squeeze()
            tokens = self.tokenizer.convert_ids_to_tokens(input_ids)
-
+
            e1_mask = _create_entity_mask2(text, input_ids, tokens, e1_start, e1_end)
            e2_mask = _create_entity_mask2(text, input_ids, tokens, e2_start, e2_end)
-
+
            return {
                "input_ids": input_ids,
                "attention_mask": attention_mask,

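A minimal usage sketch of the new `exceed_strategy="sliding_window"` path, assuming a fast (Rust-backed) tokenizer as the code above requires; the `bert-base-chinese` checkpoint, the import path, and the `label` key on each entity are illustrative assumptions, not fixed by this diff:

```python
from transformers import AutoTokenizer
from dataset import NERDataset  # assumes running from experimental/scripts/ke/

tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")  # fast tokenizer by default

samples = [{
    "text": "some very long document ..." * 200,              # longer than 512 tokens
    "entities": [{"start": 5, "end": 9, "label": "CONCEPT"}],  # `label` key is an assumption
}]

# Long samples are re-split into overlapping windows (stride = max_len // 2)
# instead of being truncated at max_len; short samples pass through unchanged.
dataset = NERDataset(samples, tokenizer, max_len=512, exceed_strategy="sliding_window")
print(len(dataset))  # usually larger than len(samples) once windows are expanded
```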
experimental/scripts/ke/model.py

Lines changed: 9 additions & 5 deletions
@@ -45,13 +45,18 @@ def forward(self, input_ids, attention_mask, token_type_ids, labels=None, **kwar
         emissions = self.hidden2label(lstm_output)  # (batch_size, max_len, num_labels)
 
         pred_label_ids = self.crf.decode(emissions, mask=attention_mask.bool())
-        pred_label_ids = torch.tensor(pred_label_ids, device=emissions.device)
+        max_len = input_ids.shape[1]
+        padded_label_ids = []
+        for seq in pred_label_ids:
+            padded_seq = seq + [0] * (max_len - len(seq))  # pad using plain list operations
+            padded_label_ids.append(padded_seq)
+
+        pred_label_ids = torch.tensor(padded_label_ids, device=emissions.device)
         pred_label_ids[pred_label_ids == 0] = 1  # post-processing: predictions may not contain IGNORE
-        pred_label_ids = F.pad(pred_label_ids, (0, input_ids.shape[1] - pred_label_ids.shape[1]), value=0, mode="constant")
 
         if labels is not None:
             valid_mask = labels != 0
-            loss = -self.crf(emissions, labels, mask=valid_mask)
+            loss = -self.crf(emissions, labels, mask=valid_mask, reduction='mean')
             return {
                 "loss": loss,
                 "pred_label_ids": pred_label_ids
@@ -93,8 +98,7 @@ def forward(self, input_ids, attention_mask, token_type_ids, e1_mask, e2_mask, l
         concat_h = torch.cat([e1_h, e2_h], dim=-1)  # (batch_size, hidden_size*2)
         concat_h = self.dropout(concat_h)
         logits = self.classifier(concat_h)  # (batch_size, num_relations)
-
-        loss = None
+
         if labels is not None:
             loss_fct = nn.CrossEntropyLoss()
             loss = loss_fct(logits.view(-1, self.num_relations), labels.view(-1))
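For context, the padding step introduced above can be looked at in isolation: a torchcrf-style `decode` returns plain Python lists whose lengths follow each sequence's mask, so they must be padded back to a common length before being stacked into one tensor. A self-contained sketch with made-up label sequences:

```python
import torch

# Typical shape of a CRF decode result: one label list per sequence,
# each only as long as that sequence's unmasked part.
decoded = [[2, 3, 3, 1], [2, 1]]
max_len = 6  # corresponds to input_ids.shape[1] in the forward pass

padded = [seq + [0] * (max_len - len(seq)) for seq in decoded]
pred_label_ids = torch.tensor(padded)
pred_label_ids[pred_label_ids == 0] = 1  # post-processing: disallow the IGNORE label
print(pred_label_ids)
# tensor([[2, 3, 3, 1, 1, 1],
#         [2, 1, 1, 1, 1, 1]])
```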

experimental/scripts/ke/predict.py

Lines changed: 4 additions & 4 deletions
@@ -142,7 +142,7 @@ def _re_predict(
 def ner_predict(
     text: str,
     model_path: str = "experimental/scripts/ke/checkpoints/ner/final_model",
-    max_len: int = 128,
+    max_len: int = 512,
     device: str = "cuda" if torch.cuda.is_available() else "cpu",
 ) -> list[dict]:
     """
@@ -151,7 +151,7 @@ def ner_predict(
     Args:
         texts (list[str]): list of texts to predict
         model_path (str): model path. Defaults to "experimental/scripts/ke/checkpoints/ner/final_model".
-        max_len (int): maximum length. Defaults to 128.
+        max_len (int): maximum length. Defaults to 512.
         device (str): device. Defaults to "cuda" or "cpu".
 
     Returns:
@@ -166,7 +166,7 @@ def re_predict(
     e1_range: tuple[int, int],
     e2_range: tuple[int, int],
     model_path: str = "experimental/scripts/ke/checkpoints/re/final_model",
-    max_len: int = 128,
+    max_len: int = 512,
     device: str = "cuda" if torch.cuda.is_available() else "cpu",
 ) -> dict:
     """
@@ -177,7 +177,7 @@ def re_predict(
         e1_range (tuple[int, int]): position of entity 1 [start:end]
         e2_range (tuple[int, int]): position of entity 2 [start:end]
         model_path (str): model path. Defaults to "experimental/scripts/ke/checkpoints/re/final_model".
-        max_len (int): maximum length. Defaults to 128.
+        max_len (int): maximum length. Defaults to 512.
         device (str): device. Defaults to "cuda" or "cpu".
 
     Returns:
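A hedged usage sketch built only from the signatures visible in this diff; the import path, the assumption that `re_predict` takes the raw text as its first argument, and the example spans are all illustrative:

```python
from predict import ner_predict, re_predict  # assumed import path

text = "..."  # up to 512 tokens are now kept before truncation

entities = ner_predict(text, max_len=512)  # -> list[dict], per the signature
print(entities)

# e1_range / e2_range are character spans [start:end] of two entities in `text`
relation = re_predict(text, (0, 4), (10, 14), max_len=512)  # -> dict, per the signature
print(relation)
```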
