Skip to content

Commit aae0be0

Browse files
committed
Added new toolnet sequence model
1 parent ed847b4 commit aae0be0

File tree

3 files changed

+128
-21
lines changed

3 files changed

+128
-21
lines changed

src/GNN/models.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,3 +537,86 @@ def forward(self, g, goalVec, goalObjectsVec, tool_vec):
537537
probNoTool = torch.sigmoid(self.activation(self.p2(probNoTool))).flatten()
538538
output = torch.cat(((1-probNoTool)*tools.flatten(), probNoTool), dim=0)
539539
return output
540+
541+
542+
######################################################################################
543+
544+
# The following models are for tool sequence prediction
545+
546+
class GGCN_Metric_Attn_L_NT_Tseq_C(nn.Module):
    """
    Sequential tool prediction model (described by its authors as the best
    performing model for the sequential tool prediction task).

    For every graph in the input sequence the model:
      1. builds per-node features from the node features concatenated with the
         dense adjacency rows of every edge type, and passes them through a
         stack of linear layers;
      2. summarises the tools used at previous steps with an LSTM
         (the ground-truth tool of step ``k-1`` is fed in at step ``k``);
      3. attends over the nodes, conditioned on the goal objects embedding and
         the LSTM output, to obtain a single scene embedding;
      4. scores each of the first ``NUMTOOLS - 1`` tools independently and
         separately predicts the likelihood that no tool is needed; tool scores
         are scaled by ``(1 - probNoTool)``.

    Constructor arguments:
        in_feats:   size of the raw node feature vector.
        n_objects:  number of objects (nodes); the first layer expects
                    ``in_feats + n_objects * 4`` inputs
                    (NOTE(review): the factor 4 presumably equals
                    ``len(etypes)`` -- confirm against the caller).
        n_hidden:   width of every hidden representation.
        n_classes:  stored on the instance; not otherwise used here.
        n_layers:   number of linear layers applied to the node features.
        etypes:     edge types whose adjacency matrices are appended to the
                    node features.
        activation, dropout, embedding, weighted:
                    accepted for signature compatibility with the other models
                    but NOT used -- the model always uses ``nn.PReLU``.
    """

    def __init__(self,
                 in_feats,
                 n_objects,
                 n_hidden,
                 n_classes,
                 n_layers,
                 etypes,
                 activation,
                 dropout,
                 embedding,
                 weighted):
        super().__init__()
        self.n_classes = n_classes
        self.etypes = etypes
        self.name = "GGCN_Metric_Attn_L_NT_Tseq_C_" + str(n_hidden) + "_" + str(n_layers)
        self.n_hidden = n_hidden

        # NOTE: sub-modules are created in the same order as before so that
        # seeded parameter initialisation and state_dict keys are unchanged.
        self.layers = nn.ModuleList()
        self.layers.append(nn.Linear(in_feats + n_objects * 4, n_hidden))
        for _ in range(n_layers - 1):
            self.layers.append(nn.Linear(n_hidden, n_hidden))

        # Encodes the history of previously used tools.
        self.tool_lstm = nn.LSTM(n_hidden, n_hidden)

        # Attention input = [node features, goal objects embedding, LSTM output].
        self.attention = nn.Linear(n_hidden + n_hidden + n_hidden, n_hidden)
        self.attention2 = nn.Linear(n_hidden, 1)

        # Shared projection of pretrained vectors (tools, goal, goal objects).
        self.embed = nn.Linear(PRETRAINED_VECTOR_SIZE, n_hidden)

        # Per-tool decoder; input = [scene, goal, LSTM output, tool embedding].
        self.fc1 = nn.Linear(4 * n_hidden, n_hidden)
        self.fc2 = nn.Linear(n_hidden, n_hidden)
        self.fc3 = nn.Linear(n_hidden, n_hidden)
        self.fc4 = nn.Linear(n_hidden, 1)

        # "No tool" likelihood head; input = [scene, goal, LSTM output].
        self.p1 = nn.Linear(3 * n_hidden, n_hidden)
        self.p2 = nn.Linear(n_hidden, 1)

        self.final = nn.Sigmoid()
        self.activation = nn.PReLU()

    def forward(self, g_list, goalVec, goalObjectsVec, tool_vec, t_list):
        """
        Predict a tool distribution for every graph in ``g_list``.

        Args:
            g_list:         sequence of graphs exposing ``ndata['feat']`` and
                            ``adjacency_matrix(etype=...)``.
            goalVec:        pretrained vector of the goal.
            goalObjectsVec: pretrained vector of the goal objects
                            (assumed 1-D, as required by the ``repeat`` call
                            below -- TODO confirm).
            tool_vec:       pretrained vectors of the tools, indexable by tool
                            position in ``TOOLS``.
            t_list:         tool names used at each step; ``t_list[k - 1]`` is
                            fed to the LSTM when processing graph ``k``.

        Returns:
            A list with one 1-D tensor per graph. Each tensor holds
            ``NUMTOOLS - 1`` tool scores scaled by ``(1 - probNoTool)``
            followed by ``probNoTool``.

        Raises:
            ValueError: if a name in ``t_list`` other than ``'no-tool'`` is not
                        in ``TOOLS``.
            IndexError: if ``t_list`` has fewer than ``len(g_list) - 1`` items.
        """
        tool_embedding = self.activation(self.embed(tool_vec))

        # Create helper tensors on the same device/dtype as the model's
        # activations so the module also works after ``.to(device)``.
        device, dtype = tool_embedding.device, tool_embedding.dtype

        # 'no-tool' has no pretrained vector: represent it by a zero embedding.
        no_tool_embedding = tool_embedding.new_zeros(self.n_hidden)
        prev_tool_embeddings = [
            (tool_embedding[TOOLS.index(tool)] if tool != 'no-tool'
             else no_tool_embedding).view(1, -1)
            for tool in t_list
        ]

        # NOTE(review): the initial LSTM state is sampled randomly on every
        # call, so predictions for steps > 0 are not deterministic unless the
        # RNG is seeded. Kept as-is to preserve training behaviour.
        lstm_hidden = (
            torch.randn(1, 1, self.n_hidden, device=device, dtype=dtype),
            torch.randn(1, 1, self.n_hidden, device=device, dtype=dtype),
        )

        goalObjectsVec = self.activation(self.embed(goalObjectsVec))
        goal_embed = self.activation(self.embed(goalVec))

        predicted_tools = []
        for ind, g in enumerate(g_list):
            # Node features augmented with dense adjacency rows of all etypes.
            h = g.ndata['feat']
            edge_matrices = [g.adjacency_matrix(etype=t) for t in self.etypes]
            edges = torch.cat(edge_matrices, 1).to_dense()
            h = torch.cat((h, edges), 1)
            for layer in self.layers:
                h = self.activation(layer(h))

            # Tool history: nothing has been used before the first step.
            if ind != 0:
                lstm_out, lstm_hidden = self.tool_lstm(
                    prev_tool_embeddings[ind - 1].view(1, 1, -1), lstm_hidden)
            else:
                lstm_out = tool_embedding.new_zeros(1, 1, self.n_hidden)
            lstm_out = lstm_out.view(-1)

            # Attention over nodes, conditioned on goal objects and history.
            n_nodes = h.size(0)
            attn_embedding = torch.cat(
                [h,
                 goalObjectsVec.repeat(n_nodes).view(n_nodes, -1),
                 lstm_out.repeat(n_nodes).view(n_nodes, -1)],
                1)
            attn_embedding = self.activation(self.attention(attn_embedding))
            attn_weights = F.softmax(self.attention2(attn_embedding), dim=0)
            scene_embedding = torch.mm(attn_weights.t(), h)

            scene_and_goal = torch.cat(
                [scene_embedding, goal_embed.view(1, -1), lstm_out.view(1, -1)], 1)

            # Score each of the first NUMTOOLS - 1 tools independently.
            tool_scores = []
            for i in range(NUMTOOLS - 1):
                final_to_decode = torch.cat(
                    [scene_and_goal, tool_embedding[i].view(1, -1)], 1)
                x = self.activation(self.fc1(final_to_decode))
                x = self.activation(self.fc2(x))
                x = self.activation(self.fc3(x))
                x = self.final(self.fc4(x))
                tool_scores.append(x.flatten())
            tools = torch.stack(tool_scores)

            # Separate likelihood that no tool is required at this step.
            probNoTool = self.activation(self.p1(scene_and_goal))
            probNoTool = torch.sigmoid(self.activation(self.p2(probNoTool))).flatten()

            output = torch.cat(((1 - probNoTool) * tools.flatten(), probNoTool), dim=0)
            predicted_tools.append(output)
        return predicted_tools

train.py

Lines changed: 45 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -233,19 +233,33 @@ def accuracy_score(dset, graphs, model, modelEnc, num_objects = 0, verbose = Fal
233233
goal_num, world_num, tools, g, t = graph
234234
if 'gcn_seq' in training:
235235
actionSeq, graphSeq = g; loss = 0; toolSeq = tools
236-
for i, g in enumerate(graphSeq):
237-
y_pred = model(g, goal2vec[goal_num], goalObjects2vec[goal_num], tool_vec)
238-
y_true = torch.zeros(NUMTOOLS)
239-
y_true[TOOLS.index(toolSeq[i])] = 1
240-
total_test_loss += l(y_pred.view(1,-1), y_true)
241-
y_pred = list(y_pred.reshape(-1))
242-
# tools_possible = dset.goal_scene_to_tools[(goal_num,world_num)]
243-
tool_predicted = TOOLS[y_pred.index(max(y_pred))]
244-
if tool_predicted == toolSeq[i]:
245-
total_correct += 1
246-
elif verbose:
247-
print (goal_num, world_num, tool_predicted, toolSeq[i])
248-
denominator += 1
236+
if 'Tseq' in model.name:
237+
y_pred = model(graphSeq, goal2vec[goal_num], goalObjects2vec[goal_num], tool_vec, tools)
238+
for i in range(len(y_pred)):
239+
y_pred_i = list(y_pred[i].reshape(-1))
240+
tool_predicted = TOOLS[y_pred_i.index(max(y_pred_i))]
241+
y_true = torch.zeros(NUMTOOLS)
242+
y_true[TOOLS.index(toolSeq[i])] = 1
243+
total_test_loss += l(y_pred[i].view(1,-1), y_true)
244+
if tool_predicted == toolSeq[i]:
245+
total_correct += 1
246+
elif verbose:
247+
print (goal_num, world_num, tool_predicted, toolSeq[i])
248+
denominator += 1
249+
else:
250+
for i, g in enumerate(graphSeq):
251+
y_pred = model(g, goal2vec[goal_num], goalObjects2vec[goal_num], tool_vec)
252+
y_true = torch.zeros(NUMTOOLS)
253+
y_true[TOOLS.index(toolSeq[i])] = 1
254+
total_test_loss += l(y_pred.view(1,-1), y_true)
255+
y_pred = list(y_pred.reshape(-1))
256+
# tools_possible = dset.goal_scene_to_tools[(goal_num,world_num)]
257+
tool_predicted = TOOLS[y_pred.index(max(y_pred))]
258+
if tool_predicted == toolSeq[i]:
259+
total_correct += 1
260+
elif verbose:
261+
print (goal_num, world_num, tool_predicted, toolSeq[i])
262+
denominator += 1
249263
continue
250264
elif 'gcn' in training:
251265
y_pred = model(g, goal2vec[goal_num], goalObjects2vec[goal_num], tool_vec)
@@ -394,13 +408,20 @@ def backprop(data, optimizer, graphs, model, num_objects, modelEnc=None, batch_s
394408
for iter_num, graph in tqdm(list(enumerate(graphs)), ncols=80):
395409
goal_num, world_num, tools, g, t = graph
396410
if 'gcn_seq' in training:
397-
actionSeq, graphSeq = g; loss = 0; toolSeq = tools
398-
for i, g in enumerate(graphSeq):
399-
y_pred = model(g, goal2vec[goal_num], goalObjects2vec[goal_num], tool_vec)
400-
y_true = torch.zeros(NUMTOOLS)
401-
y_true[TOOLS.index(tools[i])] = 1
402-
loss += l(y_pred.view(1,-1), y_true)
403-
if weighted: loss *= (1 if t == data.min_time[(goal_num, world_num)] else 0.5)
411+
actionSeq, graphSeq = g; loss = 0
412+
if 'Tseq' in model.name:
413+
y_pred = model(graphSeq, goal2vec[goal_num], goalObjects2vec[goal_num], tool_vec, tools)
414+
for i in range(len(y_pred)):
415+
y_true = torch.zeros(NUMTOOLS)
416+
y_true[TOOLS.index(tools[i])] = 1
417+
loss += l(y_pred[i].view(1,-1), y_true)
418+
else:
419+
for i,g in enumerate(graphSeq):
420+
y_pred = model(g, goal2vec[goal_num], goalObjects2vec[goal_num], tool_vec)
421+
y_true = torch.zeros(NUMTOOLS)
422+
y_true[TOOLS.index(tools[i])] = 1
423+
loss += l(y_pred.view(1,-1), y_true)
424+
if weighted: loss *= (1 if t == data.min_time[(goal_num, world_num)] else 0.5)
404425
batch_loss += loss
405426
elif 'gcn' in training:
406427
y_pred = model(g, goal2vec[goal_num], goalObjects2vec[goal_num], tool_vec)
@@ -518,7 +539,10 @@ def get_model(model_name):
518539
if training == 'gcn' or training == 'gcn_seq':
519540
size, layers = (4, 5) if training == 'gcn' else (2, 3)
520541
modelEnc = None
521-
if ("Final" not in model_name and "_NT" in model_name) or "Final_W" in model_name:
542+
if "Tseq" in model_name:
543+
model_class = getattr(src.GNN.models, "GGCN_Metric_Attn_L_NT_Tseq_C")
544+
model = model_class(data.features, data.num_objects, size * GRAPH_HIDDEN, NUMTOOLS, layers, etypes, torch.tanh, 0.5, embedding, weighted)
545+
elif ("Final" not in model_name and "_NT" in model_name) or "Final_W" in model_name:
522546
model_class = getattr(src.GNN.models, "DGL_Simple_Likelihood")
523547
model = model_class(data.features, data.num_objects, size * GRAPH_HIDDEN, NUMTOOLS, layers, etypes, torch.tanh, 0.5, embedding, weighted)
524548
else:
Binary file not shown.

0 commit comments

Comments
 (0)