Skip to content

Commit ad374f6

Browse files
committed
Add AIM inference
1 parent b536003 commit ad374f6

File tree

2 files changed

+38
-0
lines changed

2 files changed

+38
-0
lines changed

transformations/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from .aim_inference import AIMInference
12
from .auto_augment import AutoAugment
23
from .native_aspect_ratio_resize import NativeAspectRatioResize
34
from .random_crop import RandomCrop
@@ -6,6 +7,7 @@
67
from .square_resize import SquareResize
78

89
__all__ = [
10+
AIMInference,
911
AutoAugment,
1012
NativeAspectRatioResize,
1113
RandomCrop,

transformations/aim_inference.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import tensorflow as tf
2+
3+
from transformations.transformation import Transformation
4+
5+
6+
class AIMInference(Transformation):
7+
def __init__(self, resize_size, crop_size):
8+
self.resize_size = resize_size
9+
self.crop_size = crop_size
10+
self.offset = (resize_size - crop_size) // 2
11+
12+
def __call__(self, image):
13+
height = tf.cast(tf.shape(image)[0], tf.float32)
14+
width = tf.cast(tf.shape(image)[1], tf.float32)
15+
16+
if height < width:
17+
factor = self.resize_size / height
18+
else:
19+
factor = self.resize_size / width
20+
21+
new_height = tf.cast(height * factor, tf.int32)
22+
new_width = tf.cast(width * factor, tf.int32)
23+
24+
resized_image = tf.image.resize(
25+
image,
26+
[new_height, new_width],
27+
method=tf.image.ResizeMethod.BICUBIC,
28+
antialias=True
29+
)
30+
31+
cropped_image = tf.image.crop_to_bounding_box(
32+
resized_image,
33+
self.offset, self.offset,
34+
224, 224
35+
)
36+
return cropped_image

0 commit comments

Comments
 (0)