@@ -233,19 +233,33 @@ def accuracy_score(dset, graphs, model, modelEnc, num_objects = 0, verbose = Fal
233
233
goal_num , world_num , tools , g , t = graph
234
234
if 'gcn_seq' in training :
235
235
actionSeq , graphSeq = g ; loss = 0 ; toolSeq = tools
236
- for i , g in enumerate (graphSeq ):
237
- y_pred = model (g , goal2vec [goal_num ], goalObjects2vec [goal_num ], tool_vec )
238
- y_true = torch .zeros (NUMTOOLS )
239
- y_true [TOOLS .index (toolSeq [i ])] = 1
240
- total_test_loss += l (y_pred .view (1 ,- 1 ), y_true )
241
- y_pred = list (y_pred .reshape (- 1 ))
242
- # tools_possible = dset.goal_scene_to_tools[(goal_num,world_num)]
243
- tool_predicted = TOOLS [y_pred .index (max (y_pred ))]
244
- if tool_predicted == toolSeq [i ]:
245
- total_correct += 1
246
- elif verbose :
247
- print (goal_num , world_num , tool_predicted , toolSeq [i ])
248
- denominator += 1
236
+ if 'Tseq' in model .name :
237
+ y_pred = model (graphSeq , goal2vec [goal_num ], goalObjects2vec [goal_num ], tool_vec , tools )
238
+ for i in range (len (y_pred )):
239
+ y_pred_i = list (y_pred [i ].reshape (- 1 ))
240
+ tool_predicted = TOOLS [y_pred_i .index (max (y_pred_i ))]
241
+ y_true = torch .zeros (NUMTOOLS )
242
+ y_true [TOOLS .index (toolSeq [i ])] = 1
243
+ total_test_loss += l (y_pred [i ].view (1 ,- 1 ), y_true )
244
+ if tool_predicted == toolSeq [i ]:
245
+ total_correct += 1
246
+ elif verbose :
247
+ print (goal_num , world_num , tool_predicted , toolSeq [i ])
248
+ denominator += 1
249
+ else :
250
+ for i , g in enumerate (graphSeq ):
251
+ y_pred = model (g , goal2vec [goal_num ], goalObjects2vec [goal_num ], tool_vec )
252
+ y_true = torch .zeros (NUMTOOLS )
253
+ y_true [TOOLS .index (toolSeq [i ])] = 1
254
+ total_test_loss += l (y_pred .view (1 ,- 1 ), y_true )
255
+ y_pred = list (y_pred .reshape (- 1 ))
256
+ # tools_possible = dset.goal_scene_to_tools[(goal_num,world_num)]
257
+ tool_predicted = TOOLS [y_pred .index (max (y_pred ))]
258
+ if tool_predicted == toolSeq [i ]:
259
+ total_correct += 1
260
+ elif verbose :
261
+ print (goal_num , world_num , tool_predicted , toolSeq [i ])
262
+ denominator += 1
249
263
continue
250
264
elif 'gcn' in training :
251
265
y_pred = model (g , goal2vec [goal_num ], goalObjects2vec [goal_num ], tool_vec )
@@ -394,13 +408,20 @@ def backprop(data, optimizer, graphs, model, num_objects, modelEnc=None, batch_s
394
408
for iter_num , graph in tqdm (list (enumerate (graphs )), ncols = 80 ):
395
409
goal_num , world_num , tools , g , t = graph
396
410
if 'gcn_seq' in training :
397
- actionSeq , graphSeq = g ; loss = 0 ; toolSeq = tools
398
- for i , g in enumerate (graphSeq ):
399
- y_pred = model (g , goal2vec [goal_num ], goalObjects2vec [goal_num ], tool_vec )
400
- y_true = torch .zeros (NUMTOOLS )
401
- y_true [TOOLS .index (tools [i ])] = 1
402
- loss += l (y_pred .view (1 ,- 1 ), y_true )
403
- if weighted : loss *= (1 if t == data .min_time [(goal_num , world_num )] else 0.5 )
411
+ actionSeq , graphSeq = g ; loss = 0
412
+ if 'Tseq' in model .name :
413
+ y_pred = model (graphSeq , goal2vec [goal_num ], goalObjects2vec [goal_num ], tool_vec , tools )
414
+ for i in range (len (y_pred )):
415
+ y_true = torch .zeros (NUMTOOLS )
416
+ y_true [TOOLS .index (tools [i ])] = 1
417
+ loss += l (y_pred [i ].view (1 ,- 1 ), y_true )
418
+ else :
419
+ for i ,g in enumerate (graphSeq ):
420
+ y_pred = model (g , goal2vec [goal_num ], goalObjects2vec [goal_num ], tool_vec )
421
+ y_true = torch .zeros (NUMTOOLS )
422
+ y_true [TOOLS .index (tools [i ])] = 1
423
+ loss += l (y_pred .view (1 ,- 1 ), y_true )
424
+ if weighted : loss *= (1 if t == data .min_time [(goal_num , world_num )] else 0.5 )
404
425
batch_loss += loss
405
426
elif 'gcn' in training :
406
427
y_pred = model (g , goal2vec [goal_num ], goalObjects2vec [goal_num ], tool_vec )
@@ -518,7 +539,10 @@ def get_model(model_name):
518
539
if training == 'gcn' or training == 'gcn_seq' :
519
540
size , layers = (4 , 5 ) if training == 'gcn' else (2 , 3 )
520
541
modelEnc = None
521
- if ("Final" not in model_name and "_NT" in model_name ) or "Final_W" in model_name :
542
+ if "Tseq" in model_name :
543
+ model_class = getattr (src .GNN .models , "GGCN_Metric_Attn_L_NT_Tseq_C" )
544
+ model = model_class (data .features , data .num_objects , size * GRAPH_HIDDEN , NUMTOOLS , layers , etypes , torch .tanh , 0.5 , embedding , weighted )
545
+ elif ("Final" not in model_name and "_NT" in model_name ) or "Final_W" in model_name :
522
546
model_class = getattr (src .GNN .models , "DGL_Simple_Likelihood" )
523
547
model = model_class (data .features , data .num_objects , size * GRAPH_HIDDEN , NUMTOOLS , layers , etypes , torch .tanh , 0.5 , embedding , weighted )
524
548
else :
0 commit comments