Skip to content

Commit d24fbdf

Browse files
committed
Implement square resize
1 parent dd954dc commit d24fbdf

File tree

3 files changed

+18
-0
lines changed

3 files changed

+18
-0
lines changed

main.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
RandomCrop,
1212
RandomHorizontalFlip,
1313
RandomResizedCrop,
14+
SquareResize,
1415
)
1516
from transformations.transformation import Transformation
1617

@@ -49,6 +50,8 @@ def get_transformation(transformation: str) -> Transformation:
4950
return RandomCrop(
5051
scale=(0.4, 1.0), ratio=(0.75, 1.33), min_num_pixels=224 * 224
5152
)
53+
elif transformation == "square_resize":
54+
return SquareResize(224)
5255
else:
5356
raise NotImplementedError()
5457

transformations/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33
from .random_crop import RandomCrop
44
from .random_horizontal_flip import RandomHorizontalFlip
55
from .random_resized_crop import RandomResizedCrop
6+
from .square_resize import SquareResize
67

78
__all__ = [
89
AutoAugment,
910
NativeAspectRatioResize,
1011
RandomCrop,
1112
RandomHorizontalFlip,
1213
RandomResizedCrop,
14+
SquareResize,
1315
]

transformations/square_resize.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import tensorflow as tf
2+
3+
from transformations.transformation import Transformation
4+
5+
6+
class SquareResize(Transformation):
7+
def __init__(self, size: int):
8+
self.size = tf.constant([size, size])
9+
10+
def __call__(self, image):
11+
return tf.image.resize(
12+
image, self.size, method=tf.image.ResizeMethod.BICUBIC, antialias=True
13+
)

0 commit comments

Comments
 (0)