Skip to content

Commit 3ba049a

Browse files
authored
Update stable_behavior_policy.py
1 parent 269dc9a commit 3ba049a

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

nova/stable_behavior_policy.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,11 @@ def learn(self, batch, t_env):
186186
# => [n_thread, num_history, max_history_len, max_vehicle_num, obs_dim]
187187
agent_history = history[:, :, i]
188188
agent_behavior_latent = behavior_latent[:, :, i]
189-
mask = 1 - agent_terminate[:, :, i, 0]
189+
190+
if self.args.env == "MPE":
191+
mask = 1 - agent_terminate[:, :, i, 0]
192+
else:
193+
mask = agent_terminate[:, :, i, 0]
190194

191195
# Prev latent / Updated latent: [n_thread, max_vehicle_num, latent_dim]
192196
latent = torch.zeros((n_thread, max_vehicle_num, self.latent_dim)).to(self.device)

0 commit comments

Comments
 (0)