Skip to content

Commit 1d25630

Browse files
committed
devkit-v0.4
1 parent 8674f92 commit 1d25630

28 files changed

+19101
-663
lines changed

README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@
4848

4949

5050
## Changelog <a name="changelog"></a>
51+
- **`[2024/04/03]`** NAVSIM v0.4 release
52+
- Support for test phase frames of competition
53+
- Download script for trainval
54+
- Egostatus MLP Agent and training pipeline
55+
- Refactoring, Fixes, Documentation
5156
- **`[2024/03/25]`** NAVSIM v0.3 release (official devkit version for warm-up phase)
5257
- Changes env variable NUPLAN_EXP_ROOT to NAVSIM_EXP_ROOT
5358
- Adds code for Leaderboard submission

docs/agents.md

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,37 @@ Let’s dig deeper into this class. It has to implement the following methods:
3030
Details on the output format can be found below.
3131

3232
**The future trajectory has to be returned as an object of type `from navsim.common.dataclasses.Trajectory`. For examples, see the constant velocity agent or the human agent.**
33+
34+
# Learning-based Agents
35+
Most likely, your agent will involve learning-based components.
36+
Navsim provides a lightweight and easy-to-use interface for training.
37+
To use it, your agent has to implement some further functionality.
38+
In addition to the methods mentioned above, you have to implement the methods below.
39+
Have a look at `navsim.agents.ego_status_mlp_agent.EgoStatusMLPAgent` for an example.
40+
41+
- `get_feature_builders()`
42+
Has to return a List of feature builders (of type `navsim.planning.training. abstract_feature_target_builder.AbstractFeatureBuilder`).
43+
FeatureBuilders take the `AgentInput` object and compute the feature tensors used for agent training and inference. One feature builder can compute multiple feature tensors. They have to be returned in a dictionary, which is then provided to the model in the forward pass.
44+
Currently, we provide the following feature builders:
45+
- EgoStateFeatureBuilder (returns a Tensor containing current velocity, acceleration and driving command)
46+
- _the list will be increased in future devkit versions_
47+
48+
- `get_target_builders()`
49+
Similar to `get_feature_builders()`, returns the target builders of type `navsim.planning.training. abstract_feature_target_builder.AbstractTargetBuilder` used in training. In contrast to feature builders, they have access to the Scene object which contains ground-truth information (instead of just the AgentInput).
50+
51+
- `forward()`
52+
The forward pass through the model. Features are provided as a dictionary which contains all the features generated by the feature builders. All tensors are already batched and on the same device as the model. The forward pass has to output a Dict of which one entry has to be "trajectory" and contain a tensor representing the future trajectory, i.e. of shape [B, T, 3], where B is the batch size, T is the number of future timesteps and 3 refers to x,y,heading.
53+
54+
- `compute_loss`()`
55+
Given the features, the targets and the model predictions, this function computes the loss used for training. The loss has to be returned as a single Tensor.
56+
57+
- `get_optimizers()`
58+
Use this function to define the optimizers used for training.
59+
Depending on wheter you want to use a learning-rate scheduler or not, this function needs to either return just an Optimizer (of type `torch.optim.Optimizer`) or a dictionary that contains the Optimizer (key: "optimizer") and the learning-rate scheduler of type `torch.optim.lr_scheduler.LRScheduler` (key: "lr_scheduler").
60+
61+
- `compute_trajectory()`
62+
In contrast to the non-learning-based Agent, you don't have to implement this function.
63+
In inference, the trajectory will automatically be computed using the feature builders and the forward method.
3364
## Inputs
3465

3566
`get_sensor_config()` can be overwritten to determine which sensors are accessible to the agent.

docs/install.md

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,13 @@ Navigate to the download directory and download the maps
1919
cd download && ./download_maps
2020
```
2121

22-
Next download the mini split and the test split
22+
Next download the splits you want to use.
23+
You can download the mini, trainval, test and submision_test split with the following scritps
2324
```
2425
./download_mini
26+
./download_trainval
2527
./download_test
28+
./download_competition_test
2629
```
2730

2831
**The mini split and the test split take around ~160GB and ~220GB of memory respectively**
@@ -36,9 +39,13 @@ This will download the splits into the download directory. From there, move it t
3639
   ├── maps
3740
   ├── navsim_logs
3841
| ├── test
42+
| ├── trainval
43+
| ├── competition_test
3944
   │ └── mini
4045
   └── sensor_blobs
4146
├── test
47+
├── trainval
48+
├── competition_test
4249
   └── mini
4350
```
4451
Set the required environment variables, by adding the following to your `~/.bashrc` file

docs/submission.md

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,18 @@ NAVSIM comes with official leaderboards on HuggingFace. The leaderboards prevent
44

55
To submit to a leaderboard you need to create a pickle file that contains a trajectory for each test scenario. NAVSIM provides a script to create such a pickle file.
66

7-
Have a look at `run_cv_submission_evaluation.sh`: this file creates the pickle file for the ConstantVelocity agent. You can run it for your own agent by replacing the `agent` override.
8-
9-
**Note that you have to set the variables `TEAM_NAME`, `AUTHORS`, `EMAIL`, `INSTITUTION`, and `COUNTRY` for your submission to be valid.**
7+
Have a look at `run_create_submission_pickle.sh`: this file creates the pickle file for the ConstantVelocity agent. You can run it for your own agent by replacing the `agent` override.
8+
Follow the [submission instructions on huggingface](https://huggingface.co/spaces/AGC2024-P/e2e-driving-2024) to upload your submission.
9+
**Note that you have to set the variables `TEAM_NAME`, `AUTHORS`, `EMAIL`, `INSTITUTION`, and `COUNTRY` in `run_create_submission_pickle.sh` to generate a valid submission file**
1010

1111
### Warm-up track
1212
The warm-up track evaluates your submission on a [warm-up leaderboard](https://huggingface.co/spaces/AGC2024-P/e2e-driving-warmup) based on the `mini` split. This allows you to test your method and get familiar with the devkit and the submisison procedure, with a less restrictive submission budget (up to 5 submissions daily). Instructions on making a submission on HuggingFace are available in the HuggingFace space. Performance on the warm-up leaderboard is not taken into consideration for determining your team's ranking for the 2024 Autonomous Grand Challenge.
13+
Use the script `run_create_submission_pickle_warmup.sh` which already contains the overrides `scene_filter=warmup_test` and `split=mini` to generate the submission file for the warmup track.
1314

14-
You should be able to obtain the same evaluation results as on the server, by running the evaluation locally with the `warmup_test` scene filter. To do so, use the override `scene_filter=warmup_test` when executing the script to run the PDM scoring (e.g., `run_cv_pdm_score_evaluation.sh` for the constant-velocity agent).
15+
You should be able to obtain the same evaluation results as on the server, by running the evaluation locally.
16+
To do so, use the overrides `scene_filter=warmup_test` when executing the script to run the PDM scoring (e.g., `run_cv_pdm_score_evaluation.sh` for the constant-velocity agent).
1517

1618
### Formal track
17-
This is the [official challenge leaderboard](https://huggingface.co/spaces/AGC2024-P/e2e-driving-2024), based on secret held-out test frames. **Details and instructions for submission will be provided soon!**
19+
This is the [official challenge leaderboard](https://huggingface.co/spaces/AGC2024-P/e2e-driving-2024), based on secret held-out test frames (see submission_test split on the install page).
20+
Use the script `run_create_submission_pickle.sh`. It will by default run with `scene_filter=competition_test` and `split=competition_test`.
21+
You only need to set your own agent with the `agent` override.

download/download_competition_test.sh

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
wget https://huggingface.co/datasets/OpenDriveLab/OpenScene/resolve/main/openscene-v1.1/openscene_metadata_private_test_e2e.tgz
2+
tar -xzf openscene_metadata_private_test_e2e.tgz
3+
rm openscene_metadata_private_test_e2e.tgz
4+
mv competition_test competition_test_navsim_logs
5+
6+
wget https://huggingface.co/datasets/OpenDriveLab/OpenScene/resolve/main/openscene-v1.1/openscene_sensor_private_test_e2e.tgz
7+
tar -xzf openscene_sensor_private_test_e2e.tgz
8+
rm openscene_sensor_private_test_e2e.tgz
9+
mv competition_test competition_test_sensor_blobs

download/download_trainval.sh

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
wget https://huggingface.co/datasets/OpenDriveLab/OpenScene/resolve/main/openscene-v1.1/openscene_metadata_trainval.tgz
2+
tar -xzf openscene_metadata_trainval.tgz
3+
rm openscene_metadata_trainval.tgz
4+
5+
for split in {0..142}; do
6+
wget https://huggingface.co/datasets/OpenDriveLab/OpenScene/resolve/main/openscene-v1.1/openscene_sensor_trainval_camera/openscene_sensor_trainval_camera_${split}.tgz
7+
echo "Extracting file openscene_sensor_trainval_camera_${split}.tgz"
8+
tar -xzf openscene_sensor_trainval_camera_${split}.tgz
9+
rm openscene_sensor_trainval_camera_${split}.tgz
10+
done
11+
12+
for split in {0..142}; do
13+
wget https://huggingface.co/datasets/OpenDriveLab/OpenScene/resolve/main/openscene-v1.1/openscene_sensor_trainval_lidar/openscene_sensor_trainval_lidar_${split}.tgz
14+
echo "Extracting file openscene_sensor_trainval_lidar_${split}.tgz"
15+
tar -xzf openscene_sensor_trainval_lidar_${split}.tgz
16+
rm openscene_sensor_trainval_lidar_${split}.tgz
17+
done
18+
19+
mv openscene-v1.1/meta_datas trainval_navsim_logs
20+
mv openscene-v1.1/sensor_blobs trainval_sensor_blobs
21+
rm -r openscene-v1.1

navsim/agents/abstract_agent.py

Lines changed: 73 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,18 @@
1-
from __future__ import annotations
2-
3-
import abc
4-
5-
from abc import abstractmethod
6-
from typing import Any, List
1+
from abc import abstractmethod, ABC
2+
from typing import Dict, Union, List
3+
import torch
74

85
from navsim.common.dataclasses import AgentInput, Trajectory, SensorConfig
6+
from navsim.planning.training.abstract_feature_target_builder import AbstractFeatureBuilder, AbstractTargetBuilder
97

108

11-
class AbstractAgent(abc.ABC):
12-
"""
13-
Interface for a generic end-to-end agent.
14-
"""
15-
requires_scene = False
16-
17-
def __new__(cls, *args: Any, **kwargs: Any) -> AbstractAgent:
18-
"""
19-
Define attributes needed by all agents, take care when overriding.
20-
:param cls: class being constructed.
21-
:param args: arguments to constructor.
22-
:param kwargs: keyword arguments to constructor.
23-
"""
24-
instance: AbstractAgent = super().__new__(cls)
25-
instance._compute_trajectory_runtimes = []
26-
return instance
9+
class AbstractAgent(torch.nn.Module, ABC):
10+
def __init__(
11+
self,
12+
requires_scene: bool = False,
13+
):
14+
super().__init__()
15+
self.requires_scene = requires_scene
2716

2817
@abstractmethod
2918
def name(self) -> str:
@@ -39,19 +28,78 @@ def get_sensor_config(self) -> SensorConfig:
3928
"""
4029
pass
4130

42-
@abc.abstractmethod
31+
@abstractmethod
4332
def initialize(self) -> None:
4433
"""
4534
Initialize agent
4635
:param initialization: Initialization class.
4736
"""
4837
pass
4938

50-
@abc.abstractmethod
39+
def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
40+
"""
41+
Forward pass of the agent.
42+
:param features: Dictionary of features.
43+
:return: Dictionary of predictions.
44+
"""
45+
raise NotImplementedError
46+
47+
def get_feature_builders(self) -> List[AbstractFeatureBuilder]:
48+
"""
49+
:return: List of target builders.
50+
"""
51+
raise NotImplementedError("No feature builders. Agent does not support training.")
52+
53+
def get_target_builders(self) -> List[AbstractTargetBuilder]:
54+
"""
55+
:return: List of feature builders.
56+
"""
57+
raise NotImplementedError("No target builders. Agent does not support training.")
58+
5159
def compute_trajectory(self, agent_input: AgentInput) -> Trajectory:
5260
"""
5361
Computes the ego vehicle trajectory.
5462
:param current_input: Dataclass with agent inputs.
5563
:return: Trajectory representing the predicted ego's position in future
5664
"""
57-
pass
65+
features : Dict[str, torch.Tensor] = {}
66+
# build features
67+
for builder in self.get_feature_builders():
68+
features.update(builder.compute_features(agent_input))
69+
70+
# add batch dimension
71+
features = {k: v.unsqueeze(0) for k, v in features.items()}
72+
73+
# forward pass
74+
with torch.no_grad():
75+
predictions = self.forward(features)
76+
poses = predictions["trajectory"].squeeze(0).numpy()
77+
78+
# extract trajectory
79+
return Trajectory(poses)
80+
81+
def compute_loss(
82+
self,
83+
features: Dict[str, torch.Tensor],
84+
targets: Dict[str, torch.Tensor],
85+
predictions: Dict[str, torch.Tensor],
86+
) -> torch.Tensor:
87+
"""
88+
Computes the loss used for backpropagation based on the features, targets and model predictions.
89+
"""
90+
raise NotImplementedError("No loss. Agent does not support training.")
91+
92+
def get_optimizers(
93+
self
94+
) -> Union[
95+
torch.optim.Optimizer,
96+
Dict[str, Union[
97+
torch.optim.Optimizer,
98+
torch.optim.lr_scheduler.LRScheduler]
99+
]
100+
]:
101+
"""
102+
Returns the optimizers that are used by thy pytorch-lightning trainer.
103+
Has to be either a single optimizer or a dict of optimizer and lr scheduler.
104+
"""
105+
raise NotImplementedError("No optimizers. Agent does not support training.")

navsim/agents/ego_status_mlp_agent.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
from typing import Any, List, Dict
2+
from torch.optim import Optimizer
3+
from torch.optim.lr_scheduler import LRScheduler
4+
5+
from nuplan.planning.simulation.trajectory.trajectory_sampling import TrajectorySampling
6+
7+
from navsim.agents.abstract_agent import AbstractAgent
8+
from navsim.common.dataclasses import AgentInput, SensorConfig
9+
from navsim.planning.training.abstract_feature_target_builder import AbstractFeatureBuilder, AbstractTargetBuilder
10+
from navsim.common.dataclasses import Scene
11+
12+
13+
import torch
14+
15+
16+
class EgoStatusFeatureBuilder(AbstractFeatureBuilder):
17+
def __init__(self):
18+
pass
19+
20+
def compute_features(self, agent_input: AgentInput) -> Dict[str, torch.Tensor]:
21+
ego_status = agent_input.ego_statuses[-1]
22+
velocity = torch.tensor(ego_status.ego_velocity)
23+
acceleration = torch.tensor(ego_status.ego_acceleration)
24+
driving_command = torch.tensor(ego_status.driving_command)
25+
ego_state_feature = torch.cat([velocity, acceleration, driving_command], dim=-1)
26+
27+
return {"ego_state": ego_state_feature}
28+
29+
30+
class TrajectoryTargetBuilder(AbstractTargetBuilder):
31+
def __init__(self, trajectory_sampling: TrajectorySampling):
32+
self._trajectory_sampling = trajectory_sampling
33+
34+
def compute_targets(self, scene: Scene) -> Dict[str, torch.Tensor]:
35+
future_trajectory = scene.get_future_trajectory(
36+
num_trajectory_frames=self._trajectory_sampling.num_poses
37+
)
38+
return {"trajectory": torch.tensor(future_trajectory.poses)}
39+
40+
41+
class EgoStatusMLPAgent(AbstractAgent):
42+
def __init__(
43+
self,
44+
trajectory_sampling: TrajectorySampling,
45+
hidden_layer_dim: int,
46+
lr: float,
47+
checkpoint_path: str = None,
48+
):
49+
super().__init__()
50+
self._trajectory_sampling = trajectory_sampling
51+
self._checkpoint_path = checkpoint_path
52+
53+
self._lr = lr
54+
55+
self._mlp = torch.nn.Sequential(
56+
torch.nn.Linear(8, hidden_layer_dim),
57+
torch.nn.ReLU(),
58+
torch.nn.Linear(hidden_layer_dim, hidden_layer_dim),
59+
torch.nn.ReLU(),
60+
torch.nn.Linear(hidden_layer_dim, hidden_layer_dim),
61+
torch.nn.ReLU(),
62+
torch.nn.Linear(hidden_layer_dim, self._trajectory_sampling.num_poses * 3),
63+
)
64+
65+
def name(self) -> str:
66+
"""Inherited, see superclass."""
67+
68+
return self.__class__.__name__
69+
70+
def initialize(self) -> None:
71+
"""Inherited, see superclass."""
72+
state_dict: Dict[str, Any] = torch.load(self._checkpoint_path)["state_dict"]
73+
self.load_state_dict({k.replace("agent.",""):v for k,v in state_dict.items()})
74+
75+
def get_sensor_config(self) -> SensorConfig:
76+
"""Inherited, see superclass."""
77+
return SensorConfig.build_no_sensors()
78+
79+
def get_target_builders(self) -> List[AbstractTargetBuilder]:
80+
return [
81+
TrajectoryTargetBuilder(
82+
trajectory_sampling=self._trajectory_sampling
83+
),
84+
]
85+
86+
def get_feature_builders(self) -> List[AbstractFeatureBuilder]:
87+
return [EgoStatusFeatureBuilder()]
88+
89+
def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
90+
poses: torch.Tensor = self._mlp(features["ego_state"])
91+
return {"trajectory": poses.reshape(-1, self._trajectory_sampling.num_poses, 3)}
92+
93+
def compute_loss(
94+
self,
95+
features: Dict[str, torch.Tensor],
96+
targets: Dict[str, torch.Tensor],
97+
predictions: Dict[str, torch.Tensor],
98+
) -> torch.Tensor:
99+
return torch.nn.functional.l1_loss(predictions["trajectory"], targets["trajectory"])
100+
101+
def get_optimizers(self) -> Optimizer | Dict[str, Optimizer | LRScheduler]:
102+
return torch.optim.Adam(self._mlp.parameters(), lr=self._lr)

navsim/common/dataclasses.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ def __post_init__(self):
212212
class Trajectory:
213213
poses: npt.NDArray[np.float32] # local coordinates
214214
trajectory_sampling: TrajectorySampling = TrajectorySampling(
215-
time_horizon=5, interval_length=0.5
215+
time_horizon=4, interval_length=0.5
216216
)
217217

218218
def __post_init__(self):

0 commit comments

Comments
 (0)