Skip to content

Commit 7a936da

Browse files
Fix _parallelogram_to_bounding_boxes (#9181)
1 parent 64666c7 commit 7a936da

File tree

2 files changed

+52
-38
lines changed

2 files changed

+52
-38
lines changed

test/test_transforms_v2.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7016,6 +7016,29 @@ def test_parallelogram_to_bounding_boxes(input_size, device):
70167016
actual = _parallelogram_to_bounding_boxes(parallelogram)
70177017
torch.testing.assert_close(actual, expected)
70187018

7019+
# Test the transformation of a simple parallelogram.
7020+
# 1
7021+
# 1-2 / 2
7022+
# / / -> / /
7023+
# 4-3 4 /
7024+
# 3
7025+
#
7026+
# 1
7027+
# 1-2 \ 2
7028+
# \ \ -> \ \
7029+
# 4-3 4 \
7030+
# 3
7031+
parallelogram = torch.tensor(
7032+
[[0, 4, 3, 1, 5, 1, 2, 4], [0, 1, 2, 1, 5, 4, 3, 4]],
7033+
dtype=torch.float32,
7034+
)
7035+
expected = torch.tensor(
7036+
[[0, 4, 4, 0, 5, 1, 1, 5], [0, 1, 1, 0, 5, 4, 4, 5]],
7037+
dtype=torch.float32,
7038+
)
7039+
actual = _parallelogram_to_bounding_boxes(parallelogram)
7040+
torch.testing.assert_close(actual, expected)
7041+
70197042

70207043
@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, tv_tensors.Image))
70217044
@pytest.mark.parametrize("data_augmentation", ("hflip", "lsj", "multiscale", "ssd", "ssdlite"))

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 29 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -451,54 +451,45 @@ def _parallelogram_to_bounding_boxes(parallelogram: torch.Tensor) -> torch.Tenso
451451
torch.Tensor: Tensor of same shape as input containing the rectangle coordinates.
452452
The output maintains the same dtype as the input.
453453
"""
454+
original_shape = parallelogram.shape
454455
dtype = parallelogram.dtype
455456
acceptable_dtypes = [torch.float32, torch.float64]
456457
need_cast = dtype not in acceptable_dtypes
457458
if need_cast:
458459
# Upcast to float32 to avoid overflow in the squaring operations below
459460
parallelogram = parallelogram.to(torch.float32)
460-
out_boxes = parallelogram.clone()
461-
462-
# Calculate parallelogram diagonal vectors
463-
dx13 = parallelogram[..., 4] - parallelogram[..., 0]
464-
dy13 = parallelogram[..., 5] - parallelogram[..., 1]
465-
dx42 = parallelogram[..., 2] - parallelogram[..., 6]
466-
dy42 = parallelogram[..., 3] - parallelogram[..., 7]
467-
dx12 = parallelogram[..., 2] - parallelogram[..., 0]
468-
dy12 = parallelogram[..., 1] - parallelogram[..., 3]
469-
diag13 = torch.sqrt(dx13**2 + dy13**2)
470-
diag24 = torch.sqrt(dx42**2 + dy42**2)
471-
mask = diag13 > diag24
472-
473-
# Calculate rotation angle in radians
474-
r_rad = torch.atan2(dy12, dx12)
475-
cos, sin = torch.cos(r_rad), torch.sin(r_rad)
476-
477-
# Calculate width using the angle between diagonal and rotation
478-
w = torch.where(
479-
mask,
480-
diag13 * torch.abs(torch.sin(torch.atan2(dx13, dy13) - r_rad)),
481-
diag24 * torch.abs(torch.sin(torch.atan2(dx42, dy42) - r_rad)),
482-
)
483461

484-
delta_x = w * cos
485-
delta_y = w * sin
486-
# Update coordinates to form a rectangle
487-
# Keeping the points (x1, y1) and (x3, y3) unchanged.
488-
out_boxes[..., 2] = torch.where(mask, parallelogram[..., 0] + delta_x, parallelogram[..., 2])
489-
out_boxes[..., 3] = torch.where(mask, parallelogram[..., 1] - delta_y, parallelogram[..., 3])
490-
out_boxes[..., 6] = torch.where(mask, parallelogram[..., 4] - delta_x, parallelogram[..., 6])
491-
out_boxes[..., 7] = torch.where(mask, parallelogram[..., 5] + delta_y, parallelogram[..., 7])
492-
493-
# Keeping the points (x2, y2) and (x4, y4) unchanged.
494-
out_boxes[..., 0] = torch.where(~mask, parallelogram[..., 2] - delta_x, parallelogram[..., 0])
495-
out_boxes[..., 1] = torch.where(~mask, parallelogram[..., 3] + delta_y, parallelogram[..., 1])
496-
out_boxes[..., 4] = torch.where(~mask, parallelogram[..., 6] + delta_x, parallelogram[..., 4])
497-
out_boxes[..., 5] = torch.where(~mask, parallelogram[..., 7] - delta_y, parallelogram[..., 5])
462+
x1, y1, x2, y2, x3, y3, x4, y4 = parallelogram.unbind(-1)
463+
cx = (x1 + x3) / 2
464+
cy = (y1 + y3) / 2
465+
466+
# Calculate width, height, and rotation angle of the parallelogram
467+
wp = torch.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
468+
hp = torch.sqrt((x4 - x1) ** 2 + (y4 - y1) ** 2)
469+
r12 = torch.atan2(y1 - y2, x2 - x1)
470+
r14 = torch.atan2(y1 - y4, x4 - x1)
471+
r_rad = r12 - r14
472+
sign = torch.where(r_rad > torch.pi / 2, -1, 1)
473+
cos, sin = r_rad.cos(), r_rad.sin()
474+
475+
# Calculate width, height, and rotation angle of the rectangle
476+
w = torch.where(wp < hp, wp * sin, wp + hp * cos * sign)
477+
h = torch.where(wp > hp, hp * sin, hp + wp * cos * sign)
478+
r_rad = torch.where(hp > wp, r14 + torch.pi / 2, r12)
479+
cos, sin = r_rad.cos(), r_rad.sin()
480+
481+
x1 = cx - w / 2 * cos - h / 2 * sin
482+
y1 = cy - h / 2 * cos + w / 2 * sin
483+
x2 = cx + w / 2 * cos - h / 2 * sin
484+
y2 = cy - h / 2 * cos - w / 2 * sin
485+
x3 = cx + w / 2 * cos + h / 2 * sin
486+
y3 = cy + h / 2 * cos - w / 2 * sin
487+
x4 = cx - w / 2 * cos + h / 2 * sin
488+
y4 = cy + h / 2 * cos + w / 2 * sin
489+
out_boxes = torch.stack((x1, y1, x2, y2, x3, y3, x4, y4), dim=-1).reshape(original_shape)
498490

499491
if need_cast:
500492
out_boxes = out_boxes.to(dtype)
501-
502493
return out_boxes
503494

504495

0 commit comments

Comments
 (0)