@@ -121,7 +121,7 @@ def __init__(self, mode="base", jit=False, device=None, ckpt=None, fast=None):
121
121
]
122
122
)
123
123
124
- self .background = None
124
+ self .background = { 'img' : None , 'name' : None , 'shape' : None }
125
125
desc = "Mode={}, Device={}, Torchscript={}" .format (
126
126
mode , self .device , "enabled" if jit else "disabled"
127
127
)
@@ -198,10 +198,19 @@ def process(self, img, type="rgba", threshold=None):
198
198
img [border != 0 ] = [120 , 255 , 155 ]
199
199
200
200
elif type .lower ().endswith ((".jpg" , ".jpeg" , ".png" )):
201
- if self .background is None :
202
- self .background = cv2 .cvtColor (cv2 .imread (type ), cv2 .COLOR_BGR2RGB )
203
- self .background = cv2 .resize (self .background , img .shape [:2 ][::- 1 ])
204
- img = img * pred [..., np .newaxis ] + self .background * (
201
+ if self .background ['name' ] != type :
202
+ background_img = cv2 .cvtColor (cv2 .imread (type ), cv2 .COLOR_BGR2RGB )
203
+ background_img = cv2 .resize (background_img , img .shape [:2 ][::- 1 ])
204
+
205
+ self .background ['img' ] = background_img
206
+ self .background ['shape' ] = img .shape [:2 ][::- 1 ]
207
+ self .background ['name' ] = type
208
+
209
+ elif self .background ['shape' ] != img .shape [:2 ][::- 1 ]:
210
+ self .background ['img' ] = cv2 .resize (self .background ['img' ], img .shape [:2 ][::- 1 ])
211
+ self .background ['shape' ] = img .shape [:2 ][::- 1 ]
212
+
213
+ img = img * pred [..., np .newaxis ] + self .background ['img' ] * (
205
214
1 - pred [..., np .newaxis ]
206
215
)
207
216
0 commit comments