
Commit 2313e3b

FIX: merge conflict

2 parents: 055875e + 2650bb9

File tree: 28 files changed (+446, -56 lines)

crslab/config/config.py
Lines changed: 4 additions & 0 deletions

@@ -36,6 +36,10 @@ def __init__(self, config_file, gpu='-1', debug=False):
         self.opt = self.load_yaml_configs(config_file)
         # gpu
         os.environ['CUDA_VISIBLE_DEVICES'] = gpu
+        gpu = gpu.split(",")
+        for i in range(len(gpu)):
+            gpu[i] = int(gpu[i])
+        self.opt["gpu"] = gpu
         # dataset
         dataset = self.opt['dataset']
         tokenize = self.opt['tokenize']
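
For context, the lines added here turn the `gpu` string into a list of ints, with `'-1'` kept as a CPU-only sentinel. A minimal standalone sketch of the same parsing (the `parse_gpu` helper is hypothetical, not part of the commit):

    def parse_gpu(gpu: str) -> list:
        """Turn a CUDA_VISIBLE_DEVICES-style string into a list of ints."""
        return [int(g) for g in gpu.split(",")]

    assert parse_gpu("-1") == [-1]     # CPU-only sentinel
    assert parse_gpu("0") == [0]       # single GPU
    assert parse_gpu("0,1") == [0, 1]  # multi-GPU, enables DataParallel below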

crslab/model/__init__.py
Lines changed: 9 additions & 1 deletion

@@ -8,6 +8,7 @@
 # @Email  : francis_kun_zhou@163.com, wxl1999@foxmail.com

 from loguru import logger
+import torch

 from .conversation import *
 from .crs import *
@@ -43,6 +44,13 @@ def get_model(config, model_name, device, vocab, side_data=None):
     if model_name in Model_register_table:
         model = Model_register_table[model_name](config, device, vocab, side_data)
         logger.info(f'[Build model {model_name}]')
-        return model
+        if config.opt["gpu"] == [-1]:
+            return model
+        else:
+            if len(config.opt["gpu"]) > 1 and model_name == 'PMI':
+                logger.info(f'[PMI model does not support multi GPUs yet, using single GPU now]')
+                return model.to(device)
+            return torch.nn.DataParallel(model, device_ids=config["gpu"])
     else:
         raise NotImplementedError('Model [{}] has not been implemented'.format(model_name))
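
The new branch is a three-way placement dispatch on the parsed `gpu` list. A hedged sketch of the same decision, assuming a plain `nn.Module`; the `multi_gpu_ok` flag is hypothetical and stands in for the PMI special case:

    import torch.nn as nn

    def place_model(model: nn.Module, gpu: list, device, multi_gpu_ok: bool = True):
        if gpu == [-1]:
            return model                               # CPU only
        if len(gpu) > 1 and not multi_gpu_ok:          # e.g. the PMI model
            return model.to(device)                    # fall back to a single GPU
        return nn.DataParallel(model, device_ids=gpu)  # single- or multi-GPU wrap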

crslab/model/conversation/gpt2/gpt2.py
Lines changed: 42 additions & 2 deletions

@@ -64,9 +64,8 @@ def build_model(self):
         self.model = GPT2LMHeadModel.from_pretrained(self.dpath)
         self.loss = CrossEntropyLoss(ignore_index=self.pad_id)

-    def converse(self, batch, mode):
+    def forward(self, batch, mode):
         _, _, input_ids, context, _, _, y = batch
-
         if mode != 'test':
             # torch.tensor's shape = (bs, seq_len, v_s); tuple's length = 12
             lm_logits = self.model(input_ids).logits
@@ -119,3 +118,44 @@ def calculate_loss(self, logit, labels):

         loss = self.loss(logit.reshape(-1, logit.size(-1)), labels.reshape(-1))
         return loss
+
+    def generate_bs(self, context, beam=4):
+        context = context[..., -self.response_truncate + 1:]
+        context_former = context
+        batch_size = context.shape[0]
+        sequences = [[[list(), 1.0]]] * batch_size
+        for i in range(self.response_truncate - 1):
+            if sequences != [[[list(), 1.0]]] * batch_size:
+                context = []
+                for i in range(batch_size):
+                    for cand in sequences[i]:
+                        # state is no longer kept, so re-concatenate with the original context
+                        text = torch.cat((context_former[i], torch.tensor(cand[0]).to(self.device)))
+                        context.append(text)
+                context = torch.stack(context)
+            with torch.no_grad():
+                outputs = self.model(context)
+                last_hidden_state, state = outputs.logits, outputs.past_key_values
+            next_token_logits = last_hidden_state[:, -1, :]
+            next_token_probs = torch.nn.functional.softmax(next_token_logits)
+            topk = torch.topk(next_token_probs, beam, dim=-1)
+            probs = topk.values.reshape([batch_size, -1, beam])   # (bs, candidate, beam)
+            preds = topk.indices.reshape([batch_size, -1, beam])  # (bs, candidate, beam)
+
+            for j in range(batch_size):
+                all_candidates = []
+                for n in range(len(sequences[j])):
+                    for k in range(beam):
+                        seq = sequences[j][n][0]
+                        prob = sequences[j][n][1]
+                        seq_tmp = seq.copy()
+                        seq_tmp.append(preds[j][n][k])
+                        candidate = [seq_tmp, prob * probs[j][n][k]]
+                        all_candidates.append(candidate)
+                ordered = sorted(all_candidates, key=lambda tup: tup[1], reverse=True)
+                sequences[j] = ordered[:beam]
+
+        res = []
+        for i in range(batch_size):
+            res.append(torch.stack(sequences[i][0][0]))
+        res = torch.stack(res)
+        return res
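
The bookkeeping in `generate_bs` is plain beam search over `(token_list, probability)` pairs: extend every live candidate by the top-`beam` next tokens, then keep the `beam` best products. A self-contained, model-free sketch of that pruning step (the `next_probs` callable is hypothetical and stands in for one GPT-2 softmax step):

    import heapq

    def beam_step(beams, next_probs, beam=4):
        # beams: list of (token_list, prob); returns the top-`beam` extensions
        candidates = []
        for tokens, prob in beams:
            for token, p in next_probs(tokens).items():
                candidates.append((tokens + [token], prob * p))
        return heapq.nlargest(beam, candidates, key=lambda c: c[1])

    toy = lambda seq: {0: 0.6, 1: 0.3, 2: 0.1}  # fixed toy next-token distribution
    beams = [([], 1.0)]
    for _ in range(3):
        beams = beam_step(beams, toy)
    print(beams[0])  # best sequence after 3 steps: ([0, 0, 0], 0.216)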

crslab/model/conversation/transformer/transformer.py
Lines changed: 55 additions & 1 deletion

@@ -190,7 +190,61 @@ def _decode_greedy_with_kg(self, token_encoding):
         logits = torch.cat(logits, dim=1)
         return logits, inputs

-    def converse(self, batch, mode):
+    def _decode_beam_search_with_kg(self, token_encoding, beam=4):
+        batch_size = token_encoding[0].shape[0]
+        xs = self._starts(batch_size).long().reshape(1, batch_size, -1)
+        incr_state = None
+        sequences = [[[list(), list(), 1.0]]] * batch_size
+        for i in range(self.longest_label):
+            # at the beginning there is 1 candidate; when i != 0 there are `beam` candidates
+            if i == 1:
+                token_encoding = (token_encoding[0].repeat(beam, 1, 1),
+                                  token_encoding[1].repeat(beam, 1, 1))
+            if i != 0:
+                xs = []
+                for d in range(len(sequences[0])):
+                    for j in range(batch_size):
+                        text = sequences[j][d][0]
+                        xs.append(text)
+                xs = torch.stack(xs).reshape(beam, batch_size, -1)  # (beam, batch_size, _)
+
+            dialog_latent, incr_state = self.conv_decoder(xs.reshape(len(sequences[0]) * batch_size, -1),
+                                                          token_encoding,
+                                                          incr_state)
+            dialog_latent = dialog_latent[:, -1:, :]  # (bs, 1, dim)
+            gen_logits = F.linear(dialog_latent, self.token_embedding.weight)
+
+            logits = gen_logits.reshape(len(sequences[0]), batch_size, 1, -1)
+            # turn into probabilities, in case of negative numbers
+            probs, preds = torch.nn.functional.softmax(logits).topk(beam, dim=-1)
+            # (candidate, bs, 1, beam); during the first loop candidate=1, otherwise candidate=beam
+
+            for j in range(batch_size):
+                all_candidates = []
+                for n in range(len(sequences[j])):
+                    for k in range(beam):
+                        prob = sequences[j][n][2]
+                        logit = sequences[j][n][1]
+                        if logit == []:
+                            logit_tmp = logits[n][j][0].unsqueeze(0)
+                        else:
+                            logit_tmp = torch.cat((logit, logits[n][j][0].unsqueeze(0)), dim=0)
+                        seq_tmp = torch.cat((xs[n][j].reshape(-1), preds[n][j][0][k].reshape(-1)))
+                        candidate = [seq_tmp, logit_tmp, prob * probs[n][j][0][k]]
+                        all_candidates.append(candidate)
+                ordered = sorted(all_candidates, key=lambda tup: tup[2], reverse=True)
+                sequences[j] = ordered[:beam]
+
+            # check if everyone has generated an end token
+            all_finished = ((xs == self.end_token_idx).sum(dim=1) > 0).sum().item() == batch_size
+            if all_finished:
+                break
+        logits = torch.stack([seq[0][1] for seq in sequences])
+        xs = torch.stack([seq[0][0] for seq in sequences])
+        return logits, xs
+
+    def forward(self, batch, mode):
         context_tokens, context_entities, context_words, response = batch

         # encoder-decoder
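
Worth calling out: after the first step the beam candidates are folded into the batch dimension, so the encoder output is tiled once with `.repeat(beam, 1, 1)` to stay aligned with the flattened decoder input. A shape-only sketch with hypothetical sizes, not part of the commit:

    import torch

    beam, batch_size, seq_len, dim = 4, 2, 10, 300
    enc = torch.randn(batch_size, seq_len, dim)  # encoder output
    enc_tiled = enc.repeat(beam, 1, 1)           # (beam * batch_size, seq_len, dim)

    xs = torch.zeros(beam, batch_size, 5, dtype=torch.long)  # beam candidates
    flat = xs.reshape(beam * batch_size, -1)                 # beam-major flattening
    assert enc_tiled.shape[0] == flat.shape[0]  # one decoder call scores all beams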

crslab/model/crs/kbrd/kbrd.py
Lines changed: 71 additions & 3 deletions

@@ -71,6 +71,8 @@ def __init__(self, opt, device, vocab, side_data):
            side_data (dict): A dictionary record the side data.

        """
+        self.device = device
+        self.gpu = opt.get("gpu", -1)
         # vocab
         self.pad_token_idx = vocab['pad']
         self.start_token_idx = vocab['start']
@@ -172,7 +174,7 @@ def _build_conversation_layer(self):
     def encode_user(self, entity_lists, kg_embedding):
         user_repr_list = []
         for entity_list in entity_lists:
-            if not entity_list:
+            if entity_list is not None:
                 user_repr_list.append(torch.zeros(self.user_emb_dim, device=self.device))
                 continue
             user_repr = kg_embedding[entity_list]
@@ -205,17 +207,18 @@ def decode_forced(self, encoder_states, user_embedding, resp):
         return sum_logits, preds

     def decode_greedy(self, encoder_states, user_embedding):
+
         bsz = encoder_states[0].shape[0]
         xs = self._starts(bsz)
         incr_state = None
         logits = []
         for i in range(self.longest_label):
-            scores, incr_state = self.decoder(xs, encoder_states, incr_state)
+            scores, incr_state = self.decoder(xs, encoder_states, incr_state)  # incr_state is always None
             scores = scores[:, -1:, :]
             token_logits = F.linear(scores, self.token_embedding.weight)
             user_logits = self.user_proj_2(torch.relu(self.user_proj_1(user_embedding))).unsqueeze(1)
             sum_logits = token_logits + user_logits
-            _, preds = sum_logits.max(dim=-1)
+            probs, preds = sum_logits.max(dim=-1)
             logits.append(scores)
             xs = torch.cat([xs, preds], dim=1)
             # check if everyone has generated an end token
@@ -225,6 +228,62 @@ def decode_greedy(self, encoder_states, user_embedding):
         logits = torch.cat(logits, 1)
         return logits, xs

+    def decode_beam_search(self, encoder_states, user_embedding, beam=4):
+        bsz = encoder_states[0].shape[0]
+        xs = self._starts(bsz).reshape(1, bsz, -1)  # (batch_size, _)
+        sequences = [[[list(), list(), 1.0]]] * bsz
+        for i in range(self.longest_label):
+            # at the beginning there is 1 candidate; when i != 0 there are `beam` candidates
+            if i != 0:
+                xs = []
+                for d in range(len(sequences[0])):
+                    for j in range(bsz):
+                        text = sequences[j][d][0]
+                        xs.append(text)
+                xs = torch.stack(xs).reshape(beam, bsz, -1)  # (beam, batch_size, _)
+
+            with torch.no_grad():
+                if i == 1:
+                    user_embedding = user_embedding.repeat(beam, 1)
+                    encoder_states = (encoder_states[0].repeat(beam, 1, 1),
+                                      encoder_states[1].repeat(beam, 1, 1))
+
+                scores, _ = self.decoder(xs.reshape(len(sequences[0]) * bsz, -1), encoder_states)
+                scores = scores[:, -1:, :]
+                token_logits = F.linear(scores, self.token_embedding.weight)
+                user_logits = self.user_proj_2(torch.relu(self.user_proj_1(user_embedding))).unsqueeze(1)
+                sum_logits = token_logits + user_logits
+
+            logits = sum_logits.reshape(len(sequences[0]), bsz, 1, -1)
+            scores = scores.reshape(len(sequences[0]), bsz, 1, -1)
+            logits = torch.nn.functional.softmax(logits)  # turn into probabilities, in case of negative numbers
+            probs, preds = logits.topk(beam, dim=-1)
+            # (candidate, bs, 1, beam); during the first loop candidate=1, otherwise candidate=beam
+
+            for j in range(bsz):
+                all_candidates = []
+                for n in range(len(sequences[j])):
+                    for k in range(beam):
+                        prob = sequences[j][n][2]
+                        score = sequences[j][n][1]
+                        if score == []:
+                            score_tmp = scores[n][j][0].unsqueeze(0)
+                        else:
+                            score_tmp = torch.cat((score, scores[n][j][0].unsqueeze(0)), dim=0)
+                        seq_tmp = torch.cat((xs[n][j].reshape(-1), preds[n][j][0][k].reshape(-1)))
+                        candidate = [seq_tmp, score_tmp, prob * probs[n][j][0][k]]
+                        all_candidates.append(candidate)
+                ordered = sorted(all_candidates, key=lambda tup: tup[2], reverse=True)
+                sequences[j] = ordered[:beam]
+
+            # check if everyone has generated an end token
+            all_finished = ((xs == self.end_token_idx).sum(dim=1) > 0).sum().item() == bsz
+            if all_finished:
+                break
+        logits = torch.stack([seq[0][1] for seq in sequences])
+        xs = torch.stack([seq[0][0] for seq in sequences])
+        return logits, xs
+
     def converse(self, batch, mode):
         context_tokens, context_entities, response = batch['context_tokens'], batch['context_entities'], batch[
             'response']
@@ -240,3 +299,12 @@ def converse(self, batch, mode):
         else:
             _, preds = self.decode_greedy(encoder_state, user_embedding)
         return preds
+
+    def forward(self, batch, mode, stage):
+        if len(self.gpu) >= 2:
+            self.edge_idx = self.edge_idx.cuda(torch.cuda.current_device())
+            self.edge_type = self.edge_type.cuda(torch.cuda.current_device())
+        if stage == "conv":
+            return self.converse(batch, mode)
+        if stage == "rec":
+            return self.recommend(batch, mode)
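
Finally, the reason for adding `forward(batch, mode, stage)` on top of the existing `converse`/`recommend` methods: `torch.nn.DataParallel` only scatters calls made through the module's `__call__`, i.e. through `forward`, so the stage must be selected via an argument rather than by calling the methods directly. A hedged usage sketch (variable names assumed, not from the commit):

    # `model` is the (possibly DataParallel-wrapped) KBRD model from get_model
    rec_out = model(batch, 'train', 'rec')   # routed to KBRDModel.recommend
    conv_out = model(batch, 'test', 'conv')  # routed to KBRDModel.converse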
