
Commit 583e907

[Feature] Support training of BEVFusion (#2558)
* support train on nus
* refactor transfusion head
* img branch optional
* support nuscenes_mini in replace_ceph_backend
* use replace_ceph
* add only-lidar
* use valid_flag in dataset filter
* support lidar-only training 69
* fix RTS
* fix rotation in ImgAug3D
* revert to original rotation in ImgAug3D
* add LSSDepthTransform and parse_losses
* fix LoadMultiSweeps
* fix bug about points in-place operations
* support amp and replace syncBN by BN
* add amp config
* set growth-interval in amp
* Revert "fix LoadMultiSweeps"
  This reverts commit ab27ea1.
* add float in cls loss
* iter_based lr in fusion stage
* rename config
* use normalization query pos for stable training
* remove unnecessary code & simplify config & train 5 epoch
* smaller ete_min_ratio
* polish code
* fix UT
* Revert "use normalization query pos for stable training"
  This reverts commit 3009118.
* update readme
* fix height offset
1 parent ed46b8c commit 583e907

11 files changed: 693 additions, 249 deletions

mmdet3d/engine/hooks/disable_object_sample_hook.py

Lines changed: 6 additions & 1 deletion
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from mmengine.dataset import BaseDataset
 from mmengine.hooks import Hook
 from mmengine.model import is_model_wrapper
 from mmengine.runner import Runner
@@ -35,7 +36,11 @@ def before_train_epoch(self, runner: Runner):
             model = model.module
         if epoch == self.disable_after_epoch:
             runner.logger.info('Disable ObjectSample')
-            for transform in runner.train_dataloader.dataset.pipeline.transforms:  # noqa: E501
+            dataset = runner.train_dataloader.dataset
+            # handle dataset wrapper
+            if not isinstance(dataset, BaseDataset):
+                dataset = dataset.dataset
+            for transform in dataset.pipeline.transforms:  # noqa: E501
                 if isinstance(transform, ObjectSample):
                     assert hasattr(transform, 'disabled')
                     transform.disabled = True
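
The `BaseDataset` check matters because on NuScenes the training dataset is usually wrapped (for example by a class-balanced sampling wrapper such as `CBGSDataset`), so the pipeline lives on the inner `.dataset`. A minimal sketch of the lookup the hook now performs, assuming the wrapper exposes the wrapped dataset as `.dataset`:

```python
from mmengine.dataset import BaseDataset


def find_object_sample(dataset):
    """Return the ObjectSample transform of a possibly wrapped dataset.

    ``dataset`` is whatever ``runner.train_dataloader.dataset`` holds: either
    a BaseDataset subclass or a wrapper (e.g. CBGSDataset) that stores the
    real dataset on its ``.dataset`` attribute.
    """
    if not isinstance(dataset, BaseDataset):
        # dataset wrapper: the pipeline is defined one level down
        dataset = dataset.dataset
    for transform in dataset.pipeline.transforms:
        if type(transform).__name__ == 'ObjectSample':
            return transform
    return None
```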

projects/BEVFusion/README.md

Lines changed: 15 additions & 12 deletions
@@ -15,7 +15,7 @@ results is available at https://github.com/mit-han-lab/bevfusion.
 
 ## Introduction
 
-We implement BEVFusion and provide the results and pretrained checkpoints on NuScenes dataset.
+We implement BEVFusion and support training and testing on NuScenes dataset.
 
 ## Usage
 
@@ -34,38 +34,41 @@ python projects/BEVFusion/setup.py develop
 Run a demo on NuScenes data using [BEVFusion model](https://drive.google.com/file/d/1QkvbYDk4G2d6SZoeJqish13qSyXA4lp3/view?usp=share_link):
 
 ```shell
-python demo/multi_modality_demo.py demo/data/nuscenes/n015-2018-07-24-11-22-45+0800__LIDAR_TOP__1532402927647951.pcd.bin demo/data/nuscenes/ demo/data/nuscenes/n015-2018-07-24-11-22-45+0800.pkl projects/BEVFusion/configs/bevfusion_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d.py ${CHECKPOINT_FILE} --cam-type all --score-thr 0.2 --show
+python demo/multi_modality_demo.py demo/data/nuscenes/n015-2018-07-24-11-22-45+0800__LIDAR_TOP__1532402927647951.pcd.bin demo/data/nuscenes/ demo/data/nuscenes/n015-2018-07-24-11-22-45+0800.pkl projects/BEVFusion/configs/bevfusion_lidar-cam_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d.py ${CHECKPOINT_FILE} --cam-type all --score-thr 0.2 --show
 ```
 
 ### Training commands
 
-In MMDetection3D's root directory, run the following command to train the model:
+1. You should train the lidar-only detector first:
 
 ```bash
-python tools/train.py projects/BEVFusion/configs/bevfusion_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d.py
+bash tools/dist_train.sh projects/BEVFusion/configs/bevfusion_lidar_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d.py 8
 ```
 
-For multi-gpu training, run:
+2. Download the [Swin pre-trained model](https://download.openmmlab.com/mmdetection3d/v1.1.0_models/bevfusion/swint-nuimages-pretrained.pth). Given the image pre-trained backbone and the lidar-only pre-trained detector, you could train the lidar-camera fusion model:
 
 ```bash
-python -m torch.distributed.launch --nnodes=1 --node_rank=0 --nproc_per_node=${NUM_GPUS} --master_port=29506 --master_addr="127.0.0.1" tools/train.py projects/BEVFusion/configs/bevfusion_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d.py
+bash tools/dist_train.sh projects/BEVFusion/configs/bevfusion_lidar-cam_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d.py 8 --cfg-options load_from=${LIDAR_PRETRAINED_CHECKPOINT} model.img_backbone.init_cfg.checkpoint=${IMAGE_PRETRAINED_BACKBONE}
 ```
 
+**Note** that if you want to reduce CUDA memory usage and computational overhead, you could directly add `--amp` on the tail of the above commands. The model under this setting will be trained in fp16 mode.
+
 ### Testing commands
 
 In MMDetection3D's root directory, run the following command to test the model:
 
 ```bash
-python tools/test.py projects/BEVFusion/configs/bevfusion_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d.py ${CHECKPOINT_PATH}
+bash tools/dist_test.sh projects/BEVFusion/configs/bevfusion_lidar-cam_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d.py ${CHECKPOINT_PATH} 8
 ```
 
 ## Results and models
 
 ### NuScenes
 
-| Backbone | Voxel type (voxel size) | NMS | Mem (GB) | Inf time (fps) | NDS | mAP | Download |
-| :-----------------------------------------------------------------------------: | :---------------------: | :-: | :------: | :------------: | :---: | :---: | :------------------------------------------------------------------------------------------------------: |
-| [SECFPN](./configs/bevfusion_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d.py) | voxel (0.075) | × | - | - | 71.62 | 68.77 | [converted_model](https://drive.google.com/file/d/1QkvbYDk4G2d6SZoeJqish13qSyXA4lp3/view?usp=share_link) |
+| Modality | Voxel type (voxel size) | NMS | Mem (GB) | Inf time (fps) | NDS | mAP | Download |
+| :------------------------------------------------------------------------------------------: | :---------------------: | :-: | :------: | :------------: | :--: | :--: | :------: |
+| [lidar](./configs/bevfusion_lidar_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d.py) | voxel (0.075) | × | - | - | 69.6 | 64.9 | [model](https://download.openmmlab.com/mmdetection3d/v1.1.0_models/bevfusion/bevfusion_lidar_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d-2628f933.pth) [logs](https://download.openmmlab.com/mmdetection3d/v1.1.0_models/bevfusion/bevfusion_lidar_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d_20230322_053447.log) |
+| [lidar-cam](./configs/bevfusion_lidar-cam_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d.py) | voxel (0.075) | × | - | - | 71.4 | 68.6 | [model](https://download.openmmlab.com/mmdetection3d/v1.1.0_models/bevfusion/bevfusion_lidar-cam_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d-5239b1af.pth) [logs](https://download.openmmlab.com/mmdetection3d/v1.1.0_models/bevfusion/bevfusion_lidar-cam_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d_20230524_001539.log) |
 
 ## Citation
 
@@ -103,9 +106,9 @@ A project does not necessarily have to be finished in a single PR, but it's esse
 
 <!-- As this template does. -->
 
-- [ ] Milestone 2: Indicates a successful model implementation.
+- [x] Milestone 2: Indicates a successful model implementation.
 
-- [ ] Training-time correctness
+- [x] Training-time correctness
 
 <!-- If you are reproducing the result from a paper, checking this item means that you should have trained your model from scratch based on the original paper's specification and verified that the final result matches the report within a minor error range. -->
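
For context, a rough sketch of what the `--amp` flag amounts to on the config side (an assumption drawn from MMEngine's AMP support, not the literal code in `tools/train.py`): the optimizer wrapper is switched to `AmpOptimWrapper`, so forward and backward run in fp16 with a dynamic loss scaler; the commit message's "set growth-interval in amp" refers to tuning that scaler.

```python
# Hedged sketch of the AMP override implied by `--amp`; the optimizer values
# below are placeholders, not the project's exact settings.
optim_wrapper = dict(
    type='AmpOptimWrapper',                 # MMEngine mixed-precision wrapper
    loss_scale=dict(growth_interval=2000),  # dynamic GradScaler options
    optimizer=dict(type='AdamW', lr=2e-4, weight_decay=0.01),
)
```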

projects/BEVFusion/bevfusion/__init__.py

Lines changed: 6 additions & 4 deletions

@@ -1,18 +1,20 @@
 from .bevfusion import BEVFusion
 from .bevfusion_necks import GeneralizedLSSFPN
-from .depth_lss import DepthLSSTransform
+from .depth_lss import DepthLSSTransform, LSSTransform
 from .loading import BEVLoadMultiViewImageFromFiles
 from .sparse_encoder import BEVFusionSparseEncoder
 from .transformer import TransformerDecoderLayer
-from .transforms_3d import GridMask, ImageAug3D
+from .transforms_3d import (BEVFusionGlobalRotScaleTrans,
+                            BEVFusionRandomFlip3D, GridMask, ImageAug3D)
 from .transfusion_head import ConvFuser, TransFusionHead
 from .utils import (BBoxBEVL1Cost, HeuristicAssigner3D, HungarianAssigner3D,
                     IoU3DCost)
 
 __all__ = [
     'BEVFusion', 'TransFusionHead', 'ConvFuser', 'ImageAug3D', 'GridMask',
     'GeneralizedLSSFPN', 'HungarianAssigner3D', 'BBoxBEVL1Cost', 'IoU3DCost',
-    'HeuristicAssigner3D', 'DepthLSSTransform',
+    'HeuristicAssigner3D', 'DepthLSSTransform', 'LSSTransform',
     'BEVLoadMultiViewImageFromFiles', 'BEVFusionSparseEncoder',
-    'TransformerDecoderLayer'
+    'TransformerDecoderLayer', 'BEVFusionRandomFlip3D',
+    'BEVFusionGlobalRotScaleTrans'
 ]
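
Exporting `BEVFusionRandomFlip3D` and `BEVFusionGlobalRotScaleTrans` registers them so the training pipeline can reference them by name. A hypothetical pipeline fragment (illustrative argument values, not copied from the project config):

```python
# Hypothetical excerpt of a train pipeline using the newly exported transforms.
train_pipeline = [
    dict(
        type='BEVFusionGlobalRotScaleTrans',
        rot_range=[-0.78539816, 0.78539816],  # illustrative ranges
        scale_ratio_range=[0.9, 1.1],
        translation_std=0.5),
    dict(type='BEVFusionRandomFlip3D'),
]
```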

projects/BEVFusion/bevfusion/bevfusion.py

Lines changed: 105 additions & 49 deletions
@@ -1,7 +1,11 @@
-from typing import Dict, List, Optional
+from collections import OrderedDict
+from copy import deepcopy
+from typing import Dict, List, Optional, Tuple
 
 import numpy as np
 import torch
+import torch.distributed as dist
+from mmengine.utils import is_list_of
 from torch import Tensor
 from torch.nn import functional as F
 
@@ -23,7 +27,7 @@ def __init__(
         fusion_layer: Optional[dict] = None,
         img_backbone: Optional[dict] = None,
         pts_backbone: Optional[dict] = None,
-        vtransform: Optional[dict] = None,
+        view_transform: Optional[dict] = None,
         img_neck: Optional[dict] = None,
         pts_neck: Optional[dict] = None,
         bbox_head: Optional[dict] = None,
@@ -40,20 +44,21 @@ def __init__(
 
         self.pts_voxel_encoder = MODELS.build(pts_voxel_encoder)
 
-        self.img_backbone = MODELS.build(img_backbone)
-        self.img_neck = MODELS.build(img_neck)
-        self.vtransform = MODELS.build(vtransform)
+        self.img_backbone = MODELS.build(
+            img_backbone) if img_backbone is not None else None
+        self.img_neck = MODELS.build(
+            img_neck) if img_neck is not None else None
+        self.view_transform = MODELS.build(
+            view_transform) if view_transform is not None else None
         self.pts_middle_encoder = MODELS.build(pts_middle_encoder)
 
-        self.fusion_layer = MODELS.build(fusion_layer)
+        self.fusion_layer = MODELS.build(
+            fusion_layer) if fusion_layer is not None else None
 
         self.pts_backbone = MODELS.build(pts_backbone)
         self.pts_neck = MODELS.build(pts_neck)
 
         self.bbox_head = MODELS.build(bbox_head)
-        # hard code here where using converted checkpoint of original
-        # implementation of `BEVFusion`
-        self.use_converted_checkpoint = True
 
         self.init_weights()
 
@@ -67,6 +72,46 @@ def _forward(self,
         """
         pass
 
+    def parse_losses(
+        self, losses: Dict[str, torch.Tensor]
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
+        """Parses the raw outputs (losses) of the network.
+
+        Args:
+            losses (dict): Raw output of the network, which usually contain
+                losses and other necessary information.
+
+        Returns:
+            tuple[Tensor, dict]: There are two elements. The first is the
+            loss tensor passed to optim_wrapper which may be a weighted sum
+            of all losses, and the second is log_vars which will be sent to
+            the logger.
+        """
+        log_vars = []
+        for loss_name, loss_value in losses.items():
+            if isinstance(loss_value, torch.Tensor):
+                log_vars.append([loss_name, loss_value.mean()])
+            elif is_list_of(loss_value, torch.Tensor):
+                log_vars.append(
+                    [loss_name,
+                     sum(_loss.mean() for _loss in loss_value)])
+            else:
+                raise TypeError(
+                    f'{loss_name} is not a tensor or list of tensors')
+
+        loss = sum(value for key, value in log_vars if 'loss' in key)
+        log_vars.insert(0, ['loss', loss])
+        log_vars = OrderedDict(log_vars)  # type: ignore
+
+        for loss_name, loss_value in log_vars.items():
+            # reduce loss when distributed training
+            if dist.is_available() and dist.is_initialized():
+                loss_value = loss_value.data.clone()
+                dist.all_reduce(loss_value.div_(dist.get_world_size()))
+            log_vars[loss_name] = loss_value.item()
+
+        return loss, log_vars  # type: ignore
+
     def init_weights(self) -> None:
         if self.img_backbone is not None:
             self.img_backbone.init_weights()
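
To make the behaviour of `parse_losses` concrete, here is a small worked example (the loss names are made up): a raw dict of per-head outputs is reduced to one scalar for the optimizer wrapper plus a flat `log_vars` dict for the logger, with only keys containing `'loss'` contributing to the total.

```python
import torch

# Hypothetical raw output of the bbox head: scalars and lists of per-layer losses.
raw_losses = {
    'loss_heatmap': torch.tensor(0.8),
    'loss_bbox': [torch.tensor(0.5), torch.tensor(0.3)],  # list -> sum of means
    'matched_ious': torch.tensor(0.42),  # logged, but not added to the total loss
}

# loss, log_vars = model.parse_losses(raw_losses)
# -> loss == 0.8 + (0.5 + 0.3) == 1.6   (only keys containing 'loss' are summed)
# -> log_vars == {'loss': 1.6, 'loss_heatmap': 0.8,
#                 'loss_bbox': 0.8, 'matched_ious': 0.42}
```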
@@ -94,7 +139,7 @@ def extract_img_feat(
         img_metas,
     ) -> torch.Tensor:
         B, N, C, H, W = x.size()
-        x = x.view(B * N, C, H, W)
+        x = x.view(B * N, C, H, W).contiguous()
 
         x = self.img_backbone(x)
         x = self.img_neck(x)
@@ -105,22 +150,25 @@ def extract_img_feat(
         BN, C, H, W = x.size()
         x = x.view(B, int(BN / B), C, H, W)
 
-        x = self.vtransform(
-            x,
-            points,
-            lidar2image,
-            camera_intrinsics,
-            camera2lidar,
-            img_aug_matrix,
-            lidar_aug_matrix,
-            img_metas,
-        )
+        with torch.autocast(device_type='cuda', dtype=torch.float32):
+            x = self.view_transform(
+                x,
+                points,
+                lidar2image,
+                camera_intrinsics,
+                camera2lidar,
+                img_aug_matrix,
+                lidar_aug_matrix,
+                img_metas,
+            )
         return x
 
     def extract_pts_feat(self, batch_inputs_dict) -> torch.Tensor:
         points = batch_inputs_dict['points']
-        feats, coords, sizes = self.voxelize(points)
-        batch_size = coords[-1, 0] + 1
+        with torch.autocast('cuda', enabled=False):
+            points = [point.float() for point in points]
+            feats, coords, sizes = self.voxelize(points)
+            batch_size = coords[-1, 0] + 1
         x = self.pts_middle_encoder(feats, coords, batch_size)
         return x
 
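The two `torch.autocast` blocks above keep numerically sensitive stages in full precision when training with `--amp`: the view transform is pinned to float32 and voxelization opts out of autocast entirely. A standalone sketch of the same pattern with placeholder tensors (not the BEVFusion modules):

```python
import torch

x = torch.randn(4, 256, device='cuda')
w = torch.randn(256, 256, device='cuda')

with torch.autocast(device_type='cuda', dtype=torch.float16):
    y = x @ w                             # runs in fp16 under AMP
    with torch.autocast(device_type='cuda', enabled=False):
        # opt back out of autocast for a precision-sensitive step
        z = y.float() @ w                 # computed in plain fp32
    out = torch.relu(z)                   # back under AMP casting rules
```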

@@ -184,11 +232,6 @@ def predict(self, batch_inputs_dict: Dict[str, Optional[Tensor]],
 
         if self.with_bbox_head:
             outputs = self.bbox_head.predict(feats, batch_input_metas)
-            if self.use_converted_checkpoint:
-                outputs[0]['bboxes_3d'].tensor[:, 6] = -outputs[0][
-                    'bboxes_3d'].tensor[:, 6] - np.pi / 2
-                outputs[0]['bboxes_3d'].tensor[:, 3:5] = outputs[0][
-                    'bboxes_3d'].tensor[:, [4, 3]]
 
         res = self.add_pred_to_datasample(batch_data_samples, outputs)
 
@@ -202,28 +245,32 @@ def extract_feat(
     ):
         imgs = batch_inputs_dict.get('imgs', None)
         points = batch_inputs_dict.get('points', None)
-
-        lidar2image, camera_intrinsics, camera2lidar = [], [], []
-        img_aug_matrix, lidar_aug_matrix = [], []
-        for i, meta in enumerate(batch_input_metas):
-            lidar2image.append(meta['lidar2img'])
-            camera_intrinsics.append(meta['cam2img'])
-            camera2lidar.append(meta['cam2lidar'])
-            img_aug_matrix.append(meta.get('img_aug_matrix', np.eye(4)))
-            lidar_aug_matrix.append(meta.get('lidar_aug_matrix', np.eye(4)))
-
-        lidar2image = imgs.new_tensor(np.asarray(lidar2image))
-        camera_intrinsics = imgs.new_tensor(np.array(camera_intrinsics))
-        camera2lidar = imgs.new_tensor(np.asarray(camera2lidar))
-        img_aug_matrix = imgs.new_tensor(np.asarray(img_aug_matrix))
-        lidar_aug_matrix = imgs.new_tensor(np.asarray(lidar_aug_matrix))
-        img_feature = self.extract_img_feat(imgs, points, lidar2image,
-                                            camera_intrinsics, camera2lidar,
-                                            img_aug_matrix, lidar_aug_matrix,
-                                            batch_input_metas)
+        features = []
+        if imgs is not None:
+            imgs = imgs.contiguous()
+            lidar2image, camera_intrinsics, camera2lidar = [], [], []
+            img_aug_matrix, lidar_aug_matrix = [], []
+            for i, meta in enumerate(batch_input_metas):
+                lidar2image.append(meta['lidar2img'])
+                camera_intrinsics.append(meta['cam2img'])
+                camera2lidar.append(meta['cam2lidar'])
+                img_aug_matrix.append(meta.get('img_aug_matrix', np.eye(4)))
+                lidar_aug_matrix.append(
+                    meta.get('lidar_aug_matrix', np.eye(4)))
+
+            lidar2image = imgs.new_tensor(np.asarray(lidar2image))
+            camera_intrinsics = imgs.new_tensor(np.array(camera_intrinsics))
+            camera2lidar = imgs.new_tensor(np.asarray(camera2lidar))
+            img_aug_matrix = imgs.new_tensor(np.asarray(img_aug_matrix))
+            lidar_aug_matrix = imgs.new_tensor(np.asarray(lidar_aug_matrix))
+            img_feature = self.extract_img_feat(imgs, deepcopy(points),
+                                                lidar2image, camera_intrinsics,
+                                                camera2lidar, img_aug_matrix,
+                                                lidar_aug_matrix,
+                                                batch_input_metas)
+            features.append(img_feature)
         pts_feature = self.extract_pts_feat(batch_inputs_dict)
-
-        features = [img_feature, pts_feature]
+        features.append(pts_feature)
 
         if self.fusion_layer is not None:
             x = self.fusion_layer(features)
@@ -239,4 +286,13 @@ def extract_feat(
     def loss(self, batch_inputs_dict: Dict[str, Optional[Tensor]],
              batch_data_samples: List[Det3DDataSample],
              **kwargs) -> List[Det3DDataSample]:
-        pass
+        batch_input_metas = [item.metainfo for item in batch_data_samples]
+        feats = self.extract_feat(batch_inputs_dict, batch_input_metas)
+
+        losses = dict()
+        if self.with_bbox_head:
+            bbox_loss = self.bbox_head.loss(feats, batch_data_samples)
+
+        losses.update(bbox_loss)
+
+        return losses
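
Since every image-branch module is now built only when its config is given, the same `BEVFusion` class serves both training stages: the lidar-only config simply leaves the image-side components unset. A hypothetical override illustrating the idea (key names follow the constructor above; this is not the project's actual config file):

```python
# Hypothetical lidar-only override: all image-side modules are None, so
# BEVFusion builds and fuses only the point-cloud branch.
model = dict(
    type='BEVFusion',
    img_backbone=None,
    img_neck=None,
    view_transform=None,
    fusion_layer=None,
    # pts_voxel_encoder, pts_middle_encoder, pts_backbone, pts_neck and
    # bbox_head remain as in the lidar-camera config.
)
```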
