@@ -71,6 +71,8 @@ def __init__(self, opt, device, vocab, side_data):
             side_data (dict): A dictionary recording the side data.

        """
+        self.device = device
+        self.gpu = opt.get("gpu", -1)
        # vocab
        self.pad_token_idx = vocab['pad']
        self.start_token_idx = vocab['start']
@@ -172,7 +174,7 @@ def _build_conversation_layer(self):
    def encode_user(self, entity_lists, kg_embedding):
        user_repr_list = []
        for entity_list in entity_lists:
-            if not entity_list:
+            if entity_list is None:
                user_repr_list.append(torch.zeros(self.user_emb_dim, device=self.device))
                continue
            user_repr = kg_embedding[entity_list]
@@ -205,17 +207,18 @@ def decode_forced(self, encoder_states, user_embedding, resp):
        return sum_logits, preds

    def decode_greedy(self, encoder_states, user_embedding):
+
        bsz = encoder_states[0].shape[0]
        xs = self._starts(bsz)
        incr_state = None
        logits = []
        for i in range(self.longest_label):
-            scores, incr_state = self.decoder(xs, encoder_states, incr_state)
+            scores, incr_state = self.decoder(xs, encoder_states, incr_state)  # incr_state is always None
            scores = scores[:, -1:, :]
            token_logits = F.linear(scores, self.token_embedding.weight)
            user_logits = self.user_proj_2(torch.relu(self.user_proj_1(user_embedding))).unsqueeze(1)
            sum_logits = token_logits + user_logits
-            _, preds = sum_logits.max(dim=-1)
+            probs, preds = sum_logits.max(dim=-1)
            logits.append(scores)
            xs = torch.cat([xs, preds], dim=1)
            # check if everyone has generated an end token
@@ -225,6 +228,63 @@ def decode_greedy(self, encoder_states, user_embedding):
        logits = torch.cat(logits, 1)
        return logits, xs

+    def decode_beam_search(self, encoder_states, user_embedding, beam=4):
+        bsz = encoder_states[0].shape[0]
+        xs = self._starts(bsz).reshape(1, bsz, -1)  # (batch_size, _)
+        sequences = [[[list(), list(), 1.0]]] * bsz
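+        # each candidate is a triple [token id sequence, per-step scores, cumulative probability]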
+        for i in range(self.longest_label):
+            # at the beginning there is 1 candidate; when i != 0 there are `beam` candidates
+            if i != 0:
+                xs = []
+                for d in range(len(sequences[0])):
+                    for j in range(bsz):
+                        text = sequences[j][d][0]
+                        xs.append(text)
+                xs = torch.stack(xs).reshape(beam, bsz, -1)  # (beam, batch_size, _)
+
+            with torch.no_grad():
+                if i == 1:
+                    user_embedding = user_embedding.repeat(beam, 1)
+                    encoder_states = (encoder_states[0].repeat(beam, 1, 1),
+                                      encoder_states[1].repeat(beam, 1, 1))
+
+                scores, _ = self.decoder(xs.reshape(len(sequences[0]) * bsz, -1), encoder_states)
+                scores = scores[:, -1:, :]
+                token_logits = F.linear(scores, self.token_embedding.weight)
+                user_logits = self.user_proj_2(torch.relu(self.user_proj_1(user_embedding))).unsqueeze(1)
+                sum_logits = token_logits + user_logits
+
+            logits = sum_logits.reshape(len(sequences[0]), bsz, 1, -1)
+            scores = scores.reshape(len(sequences[0]), bsz, 1, -1)
+            logits = torch.nn.functional.softmax(logits, dim=-1)  # turn logits into probabilities, in case of negative numbers
+            probs, preds = logits.topk(beam, dim=-1)
+            # shapes: (candidate, bsz, 1, beam); in the first loop candidate = 1, afterwards candidate = beam
+
+            for j in range(bsz):
+                all_candidates = []
+                for n in range(len(sequences[j])):
+                    for k in range(beam):
+                        prob = sequences[j][n][2]
+                        score = sequences[j][n][1]
+                        if len(score) == 0:
+                            score_tmp = scores[n][j][0].unsqueeze(0)
+                        else:
+                            score_tmp = torch.cat((score, scores[n][j][0].unsqueeze(0)), dim=0)
+                        seq_tmp = torch.cat((xs[n][j].reshape(-1), preds[n][j][0][k].reshape(-1)))
+                        candidate = [seq_tmp, score_tmp, prob * probs[n][j][0][k]]
+                        all_candidates.append(candidate)
+                ordered = sorted(all_candidates, key=lambda tup: tup[2], reverse=True)
+                sequences[j] = ordered[:beam]
+
+            # check if everyone has generated an end token
+            all_finished = ((xs == self.end_token_idx).sum(dim=1) > 0).sum().item() == bsz
+            if all_finished:
+                break
+        logits = torch.stack([seq[0][1] for seq in sequences])
+        xs = torch.stack([seq[0][0] for seq in sequences])
+        return logits, xs
+
    def converse(self, batch, mode):
        context_tokens, context_entities, response = batch['context_tokens'], batch['context_entities'], batch[
            'response']
@@ -240,3 +300,13 @@ def converse(self, batch, mode):
        else:
            _, preds = self.decode_greedy(encoder_state, user_embedding)
        return preds
+
+    def forward(self, batch, mode, stage):
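+        # `gpu` comes from opt["gpu"] and is expected to be a list of device ids here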
+        if len(self.gpu) >= 2:
+            self.edge_idx = self.edge_idx.cuda(torch.cuda.current_device())
+            self.edge_type = self.edge_type.cuda(torch.cuda.current_device())
+        if stage == "conv":
+            return self.converse(batch, mode)
+        if stage == "rec":
+            return self.recommend(batch, mode)
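
A minimal, self-contained sketch of the expand-and-prune step that decode_beam_search performs at each position: every candidate is a [token ids, per-step scores, cumulative probability] triple, each candidate is extended with its `beam` most probable next tokens, and only the `beam` best extensions survive. The names `expand` and `step_probs`, and the random stand-in for the decoder distribution, are illustrative and not part of the model.

    import torch

    def expand(candidates, step_probs, beam=4):
        """One beam step: extend every candidate, keep the `beam` best."""
        all_candidates = []
        for tokens, steps, prob in candidates:
            top_p, top_i = step_probs.topk(beam)  # `beam` most probable next tokens
            for p, idx in zip(top_p.tolist(), top_i.tolist()):
                all_candidates.append([
                    torch.cat([tokens, torch.tensor([idx])]),  # extended token sequence
                    steps + [p],                               # per-step probabilities
                    prob * p,                                  # cumulative probability
                ])
        # prune: keep only the `beam` most probable candidates
        return sorted(all_candidates, key=lambda c: c[2], reverse=True)[:beam]

    candidates = [[torch.tensor([0]), [], 1.0]]  # start token, no scores yet, prob 1.0
    for _ in range(3):  # three decoding steps
        step_probs = torch.softmax(torch.randn(10), dim=-1)  # stand-in for one decoder softmax output
        candidates = expand(candidates, step_probs, beam=4)
    print(candidates[0][0])  # most probable token sequence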