Skip to content

Commit 7342a89

Browse files
authored
Update behavior_policy.py
1 parent 3ba049a commit 7342a89

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

nova/behavior_policy.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -142,7 +142,11 @@ def learn(self, batch, t_env):
142142
agent_history = agent_history.reshape(n_thread, num_history, self.max_history_len, max_vehicle_num, obs_dim)
143143

144144
cut_len = (num_history - 1) * self.max_history_len
145-
mask_cut = agent_terminate[:, :cut_len, i, 0]
145+
146+
if self.args.env == "MPE":
147+
mask_cut = 1 - agent_terminate[:, :cut_len, i, 0]
148+
else:
149+
mask_cut = agent_terminate[:, :cut_len, i, 0]
146150

147151
# Mask over given episode
148152
mask = mask_cut.reshape(n_thread, cut_len, 1, 1).tile(1, 1, max_vehicle_num, obs_dim)

0 commit comments

Comments
 (0)