23
23
from mmdet3d .registry import VISUALIZERS
24
24
from mmdet3d .structures import (BaseInstance3DBoxes , Box3DMode ,
25
25
CameraInstance3DBoxes , Coord3DMode ,
26
- DepthInstance3DBoxes , Det3DDataSample ,
27
- LiDARInstance3DBoxes , PointData ,
28
- points_cam2img )
26
+ DepthInstance3DBoxes , DepthPoints ,
27
+ Det3DDataSample , LiDARInstance3DBoxes ,
28
+ PointData , points_cam2img )
29
29
from .vis_utils import (proj_camera_bbox3d_to_img , proj_depth_bbox3d_to_img ,
30
30
proj_lidar_bbox3d_to_img , to_depth_mode )
31
31
@@ -293,7 +293,7 @@ def draw_bboxes_3d(self,
293
293
# convert bboxes to numpy dtype
294
294
bboxes_3d = tensor2ndarray (bboxes_3d .tensor )
295
295
296
- in_box_color = np .array (points_in_box_color )
296
+ # in_box_color = np.array(points_in_box_color)
297
297
298
298
for i in range (len (bboxes_3d )):
299
299
center = bboxes_3d [i , 0 :3 ]
@@ -320,7 +320,7 @@ def draw_bboxes_3d(self,
320
320
if self .pcd is not None and mode == 'xyz' :
321
321
indices = box3d .get_point_indices_within_bounding_box (
322
322
self .pcd .points )
323
- self .points_colors [indices ] = in_box_color
323
+ self .points_colors [indices ] = np . array ( bbox_color [ i ]) / 255.
324
324
325
325
# update points colors
326
326
if self .pcd is not None :
@@ -606,6 +606,7 @@ def _draw_instances_3d(self,
606
606
instances : InstanceData ,
607
607
input_meta : dict ,
608
608
vis_task : str ,
609
+ show_pcd_rgb : bool = False ,
609
610
palette : Optional [List [tuple ]] = None ) -> dict :
610
611
"""Draw 3D instances of GT or prediction.
611
612
@@ -616,6 +617,7 @@ def _draw_instances_3d(self,
616
617
input_meta (dict): Meta information.
617
618
vis_task (str): Visualization task, it includes: 'lidar_det',
618
619
'multi-modality_det', 'mono_det'.
620
+ show_pcd_rgb (bool): Whether to show RGB point cloud.
619
621
palette (List[tuple], optional): Palette information corresponding
620
622
to the category. Defaults to None.
621
623
@@ -643,13 +645,22 @@ def _draw_instances_3d(self,
643
645
else :
644
646
bboxes_3d_depth = bboxes_3d .clone ()
645
647
648
+ if 'axis_align_matrix' in input_meta :
649
+ points = DepthPoints (points , points_dim = points .shape [1 ])
650
+ rot_mat = input_meta ['axis_align_matrix' ][:3 , :3 ]
651
+ trans_vec = input_meta ['axis_align_matrix' ][:3 , - 1 ]
652
+ points .rotate (rot_mat .T )
653
+ points .translate (trans_vec )
654
+ points = tensor2ndarray (points .tensor )
655
+
646
656
max_label = int (max (labels_3d ) if len (labels_3d ) > 0 else 0 )
647
657
bbox_color = palette if self .bbox_color is None \
648
658
else self .bbox_color
649
659
bbox_palette = get_palette (bbox_color , max_label + 1 )
650
660
colors = [bbox_palette [label ] for label in labels_3d ]
651
661
652
- self .set_points (points , pcd_mode = 2 )
662
+ self .set_points (
663
+ points , pcd_mode = 2 , mode = 'xyzrgb' if show_pcd_rgb else 'xyz' )
653
664
self .draw_bboxes_3d (bboxes_3d_depth , bbox_color = colors )
654
665
655
666
data_3d ['bboxes_3d' ] = tensor2ndarray (bboxes_3d_depth .tensor )
@@ -871,7 +882,7 @@ def show(self,
871
882
self .o3d_vis .clear_geometries ()
872
883
try :
873
884
del self .pcd
874
- except KeyError :
885
+ except ( KeyError , AttributeError ) :
875
886
pass
876
887
if save_path is not None :
877
888
if not (save_path .endswith ('.png' )
@@ -923,7 +934,8 @@ def add_datasample(self,
923
934
o3d_save_path : Optional [str ] = None ,
924
935
vis_task : str = 'mono_det' ,
925
936
pred_score_thr : float = 0.3 ,
926
- step : int = 0 ) -> None :
937
+ step : int = 0 ,
938
+ show_pcd_rgb : bool = False ) -> None :
927
939
"""Draw datasample and save to all backends.
928
940
929
941
- If GT and prediction are plotted at the same time, they are displayed
@@ -954,6 +966,8 @@ def add_datasample(self,
954
966
pred_score_thr (float): The threshold to visualize the bboxes
955
967
and masks. Defaults to 0.3.
956
968
step (int): Global step value to record. Defaults to 0.
969
+ show_pcd_rgb (bool): Whether to show RGB point cloud. Defaults to
970
+ False.
957
971
"""
958
972
assert vis_task in (
959
973
'mono_det' , 'multi-view_det' , 'lidar_det' , 'lidar_seg' ,
@@ -976,7 +990,7 @@ def add_datasample(self,
976
990
if 'gt_instances_3d' in data_sample :
977
991
gt_data_3d = self ._draw_instances_3d (
978
992
data_input , data_sample .gt_instances_3d ,
979
- data_sample .metainfo , vis_task , palette )
993
+ data_sample .metainfo , vis_task , show_pcd_rgb , palette )
980
994
if 'gt_instances' in data_sample :
981
995
if len (data_sample .gt_instances ) > 0 :
982
996
assert 'img' in data_input
@@ -1006,7 +1020,8 @@ def add_datasample(self,
1006
1020
pred_data_3d = self ._draw_instances_3d (data_input ,
1007
1021
pred_instances_3d ,
1008
1022
data_sample .metainfo ,
1009
- vis_task , palette )
1023
+ vis_task , show_pcd_rgb ,
1024
+ palette )
1010
1025
if 'pred_instances' in data_sample :
1011
1026
if 'img' in data_input and len (data_sample .pred_instances ) > 0 :
1012
1027
pred_instances = data_sample .pred_instances
0 commit comments