@@ -451,54 +451,45 @@ def _parallelogram_to_bounding_boxes(parallelogram: torch.Tensor) -> torch.Tenso
451
451
torch.Tensor: Tensor of same shape as input containing the rectangle coordinates.
452
452
The output maintains the same dtype as the input.
453
453
"""
454
+ original_shape = parallelogram .shape
454
455
dtype = parallelogram .dtype
455
456
acceptable_dtypes = [torch .float32 , torch .float64 ]
456
457
need_cast = dtype not in acceptable_dtypes
457
458
if need_cast :
458
459
# Up-case to avoid overflow for square operations
459
460
parallelogram = parallelogram .to (torch .float32 )
460
- out_boxes = parallelogram .clone ()
461
-
462
- # Calculate parallelogram diagonal vectors
463
- dx13 = parallelogram [..., 4 ] - parallelogram [..., 0 ]
464
- dy13 = parallelogram [..., 5 ] - parallelogram [..., 1 ]
465
- dx42 = parallelogram [..., 2 ] - parallelogram [..., 6 ]
466
- dy42 = parallelogram [..., 3 ] - parallelogram [..., 7 ]
467
- dx12 = parallelogram [..., 2 ] - parallelogram [..., 0 ]
468
- dy12 = parallelogram [..., 1 ] - parallelogram [..., 3 ]
469
- diag13 = torch .sqrt (dx13 ** 2 + dy13 ** 2 )
470
- diag24 = torch .sqrt (dx42 ** 2 + dy42 ** 2 )
471
- mask = diag13 > diag24
472
-
473
- # Calculate rotation angle in radians
474
- r_rad = torch .atan2 (dy12 , dx12 )
475
- cos , sin = torch .cos (r_rad ), torch .sin (r_rad )
476
-
477
- # Calculate width using the angle between diagonal and rotation
478
- w = torch .where (
479
- mask ,
480
- diag13 * torch .abs (torch .sin (torch .atan2 (dx13 , dy13 ) - r_rad )),
481
- diag24 * torch .abs (torch .sin (torch .atan2 (dx42 , dy42 ) - r_rad )),
482
- )
483
461
484
- delta_x = w * cos
485
- delta_y = w * sin
486
- # Update coordinates to form a rectangle
487
- # Keeping the points (x1, y1) and (x3, y3) unchanged.
488
- out_boxes [..., 2 ] = torch .where (mask , parallelogram [..., 0 ] + delta_x , parallelogram [..., 2 ])
489
- out_boxes [..., 3 ] = torch .where (mask , parallelogram [..., 1 ] - delta_y , parallelogram [..., 3 ])
490
- out_boxes [..., 6 ] = torch .where (mask , parallelogram [..., 4 ] - delta_x , parallelogram [..., 6 ])
491
- out_boxes [..., 7 ] = torch .where (mask , parallelogram [..., 5 ] + delta_y , parallelogram [..., 7 ])
492
-
493
- # Keeping the points (x2, y2) and (x4, y4) unchanged.
494
- out_boxes [..., 0 ] = torch .where (~ mask , parallelogram [..., 2 ] - delta_x , parallelogram [..., 0 ])
495
- out_boxes [..., 1 ] = torch .where (~ mask , parallelogram [..., 3 ] + delta_y , parallelogram [..., 1 ])
496
- out_boxes [..., 4 ] = torch .where (~ mask , parallelogram [..., 6 ] + delta_x , parallelogram [..., 4 ])
497
- out_boxes [..., 5 ] = torch .where (~ mask , parallelogram [..., 7 ] - delta_y , parallelogram [..., 5 ])
462
+ x1 , y1 , x2 , y2 , x3 , y3 , x4 , y4 = parallelogram .unbind (- 1 )
463
+ cx = (x1 + x3 ) / 2
464
+ cy = (y1 + y3 ) / 2
465
+
466
+ # Calculate width, height, and rotation angle of the parallelogram
467
+ wp = torch .sqrt ((x2 - x1 ) ** 2 + (y2 - y1 ) ** 2 )
468
+ hp = torch .sqrt ((x4 - x1 ) ** 2 + (y4 - y1 ) ** 2 )
469
+ r12 = torch .atan2 (y1 - y2 , x2 - x1 )
470
+ r14 = torch .atan2 (y1 - y4 , x4 - x1 )
471
+ r_rad = r12 - r14
472
+ sign = torch .where (r_rad > torch .pi / 2 , - 1 , 1 )
473
+ cos , sin = r_rad .cos (), r_rad .sin ()
474
+
475
+ # Calculate width, height, and rotation angle of the rectangle
476
+ w = torch .where (wp < hp , wp * sin , wp + hp * cos * sign )
477
+ h = torch .where (wp > hp , hp * sin , hp + wp * cos * sign )
478
+ r_rad = torch .where (hp > wp , r14 + torch .pi / 2 , r12 )
479
+ cos , sin = r_rad .cos (), r_rad .sin ()
480
+
481
+ x1 = cx - w / 2 * cos - h / 2 * sin
482
+ y1 = cy - h / 2 * cos + w / 2 * sin
483
+ x2 = cx + w / 2 * cos - h / 2 * sin
484
+ y2 = cy - h / 2 * cos - w / 2 * sin
485
+ x3 = cx + w / 2 * cos + h / 2 * sin
486
+ y3 = cy + h / 2 * cos - w / 2 * sin
487
+ x4 = cx - w / 2 * cos + h / 2 * sin
488
+ y4 = cy + h / 2 * cos + w / 2 * sin
489
+ out_boxes = torch .stack ((x1 , y1 , x2 , y2 , x3 , y3 , x4 , y4 ), dim = - 1 ).reshape (original_shape )
498
490
499
491
if need_cast :
500
492
out_boxes = out_boxes .to (dtype )
501
-
502
493
return out_boxes
503
494
504
495
0 commit comments