 import omni.isaac.lab.utils.math as math_utils
 from omni.isaac.lab.assets import Articulation, RigidObject
 from omni.isaac.lab.managers import SceneEntityCfg
+from omni.isaac.lab.managers.manager_base import ManagerTermBase
+from omni.isaac.lab.managers.manager_term_cfg import ObservationTermCfg
 from omni.isaac.lab.sensors import Camera, RayCaster, RayCasterCamera, TiledCamera


 if TYPE_CHECKING:
@@ -233,61 +235,69 @@ def image(
     return images.clone()


-def image_features(
-    env: ManagerBasedEnv,
-    sensor_cfg: SceneEntityCfg = SceneEntityCfg("tiled_camera"),
-    data_type: str = "rgb",
-    convert_perspective_to_orthogonal: bool = True,
-    model_name: str = "Theia",
-    model_zoo_cfg: dict | None = None,
-) -> torch.Tensor:
-    """Extracted image features with a frozen encoder from Images of a specific datatype from the camera sensor.
-
-    Args:
-        env: The environment the cameras are placed within.
-        sensor_cfg: The desired sensor to read from. Defaults to SceneEntityCfg("tiled_camera").
-        data_type: The data type to pull from the desired camera. Defaults to "rgb".
-        model_name: The name of which model to use from the model_zoo_cfg to use to extract features.
-        model_zoo_cfg: A dictionary with string keys and callable values. Should include "model",
-            (mapped to a callable with no arguments to return the model), "preprocess" (mapped to
-            a callable which consumes the images and returns the preprocessed images),
-            and "inference" (mapped to a callable that provided the model, and the preproccessed images,
-            returns the features.)
+class image_features(ManagerTermBase):
+    """Extracted image features with a frozen encoder from images of a specific datatype from the camera sensor.

-    Returns:
-        The features from the images produced at the last timestep
+    Calls :func:`image` to get the images, then performs inference with the frozen encoder. To use a model zoo
+    different from the default, pass ``model_zoo_cfg`` on initialization: a dictionary whose string keys are model
+    names and whose values are dictionaries of callables, mapping "model" to a callable with no arguments that
+    returns the model, "preprocess" to a callable which consumes the images and returns the preprocessed images,
+    and "inference" to a callable that, given the model and the preprocessed images, returns the features.
     """
|
260 |
| - if not hasattr(image_features, "model_zoo"): |
261 |
| - image_features.model_zoo = {} |
262 |
| - |
263 |
| - if model_zoo_cfg is None: |
264 |
| - model_zoo_cfg = { |
265 |
| - "ResNet18": { |
266 |
| - "model": lambda: models.resnet18(pretrained=True).eval().to("cuda:0"), |
267 |
| - "preprocess": lambda img: ( |
268 |
| - img.permute(0, 3, 1, 2) # Convert [batch, height, width, 3] -> [batch, 3, height, width] |
269 |
| - - torch.tensor([0.485, 0.456, 0.406], device=img.device).view(1, 3, 1, 1) |
270 |
| - ) / torch.tensor([0.229, 0.224, 0.225], device=img.device).view(1, 3, 1, 1), |
271 |
| - "inference": lambda model, images: model(images), |
272 |
| - }, |
273 |
| - } |
274 |
| - |
275 |
| - if model_name not in image_features.model_zoo: |
276 |
| - print(f"[INFO]: Adding {model_name} to persistent frozen feature extraction model zoo...") |
277 |
| - image_features.model_zoo[model_name] = model_zoo_cfg[model_name]["model"]() |
278 |
| - |
279 |
| - images = image( |
280 |
| - env=env, |
281 |
| - sensor_cfg=sensor_cfg, |
282 |
| - data_type=data_type, |
283 |
| - convert_perspective_to_orthogonal=convert_perspective_to_orthogonal, |
284 |
| - normalize=True, # want this for training stability |
285 |
| - ) |
286 |
| - |
287 |
| - proc_images = model_zoo_cfg[model_name]["preprocess"](images) |
288 |
| - features = model_zoo_cfg[model_name]["inference"](image_features.model_zoo[model_name], proc_images) |
289 | 247 |
|
290 |
| - return features |
+    def __init__(
+        self,
+        cfg: ObservationTermCfg,
+        env: ManagerBasedEnv,
+        model_zoo_cfg: dict | None = None,
+        initialize_all: bool = False,
+    ):
+        super().__init__(cfg, env)
+        self.model_zoo_cfg = model_zoo_cfg
+        if self.model_zoo_cfg is None:
+            self.model_zoo_cfg = {
+                "ResNet18": {
+                    "model": lambda: models.resnet18(pretrained=True).eval().to("cuda:0"),
+                    "preprocess": lambda img: (
+                        img.permute(0, 3, 1, 2)  # Convert [batch, height, width, 3] -> [batch, 3, height, width]
+                        # Normalize in the format expected by PyTorch: https://pytorch.org/hub/pytorch_vision_resnet/
+                        - torch.tensor([0.485, 0.456, 0.406], device=img.device).view(1, 3, 1, 1)
+                    ) / torch.tensor([0.229, 0.224, 0.225], device=img.device).view(1, 3, 1, 1),
+                    "inference": lambda model, images: model(images),
+                },
+            }
+        self.reset_model(initialize_all=initialize_all)
+
+    # Named reset_model instead of reset, as reset is called by the manager at the end of every episode.
+    def reset_model(self, initialize_all: bool = False):
+        self.model_zoo = {}
+        if initialize_all:
+            for model_name, model_callables in self.model_zoo_cfg.items():
+                self.model_zoo[model_name] = model_callables["model"]()
+
+    def __call__(
+        self,
+        env: ManagerBasedEnv,
+        sensor_cfg: SceneEntityCfg = SceneEntityCfg("tiled_camera"),
+        data_type: str = "rgb",
+        convert_perspective_to_orthogonal: bool = False,
+        model_name: str = "ResNet18",
+    ) -> torch.Tensor:
+        # Lazily construct and cache the frozen encoder on first use.
+        if model_name not in self.model_zoo:
+            print(f"[INFO]: Adding {model_name} to the model zoo...")
+            self.model_zoo[model_name] = self.model_zoo_cfg[model_name]["model"]()
+
+        images = image(
+            env=env,
+            sensor_cfg=sensor_cfg,
+            data_type=data_type,
+            convert_perspective_to_orthogonal=convert_perspective_to_orthogonal,
+            normalize=True,  # want this for training stability
+        )
+
+        proc_images = self.model_zoo_cfg[model_name]["preprocess"](images)
+        features = self.model_zoo_cfg[model_name]["inference"](self.model_zoo[model_name], proc_images)
+
+        return features


 """
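
A minimal usage sketch, assuming Isaac Lab's manager-based workflow: a class-based term is referenced as the term's func, and the keyword arguments of __call__ are routed through params. The PolicyCfg name, the "tiled_camera" scene entity, and the mdp import path are assumptions for illustration, not part of this diff.

import omni.isaac.lab.envs.mdp as mdp
from omni.isaac.lab.managers import ObservationGroupCfg as ObsGroup
from omni.isaac.lab.managers import ObservationTermCfg as ObsTerm
from omni.isaac.lab.managers import SceneEntityCfg
from omni.isaac.lab.utils import configclass


@configclass
class PolicyCfg(ObsGroup):  # hypothetical observation group for a camera-based policy
    # Frozen ResNet18 features of the tiled camera's RGB images.
    image_feats = ObsTerm(
        func=mdp.image_features,
        params={
            "sensor_cfg": SceneEntityCfg("tiled_camera"),
            "data_type": "rgb",
            "model_name": "ResNet18",
        },
    )

Because the term is now a class, the encoder is cached on the instance (self.model_zoo) rather than as an attribute on the function, and the cache survives episode resets since only reset_model, not reset, clears it.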
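A sketch of a custom model zoo entry following the contract described in the new docstring: "model" builds the frozen encoder, "preprocess" converts the [batch, height, width, 3] camera images into the encoder's expected input, and "inference" produces the features. The ResNet34 choice and the device string are assumptions for illustration; how model_zoo_cfg reaches __init__ (e.g., via the term configuration) is left open by this diff.

import torch
import torchvision.models as models

custom_model_zoo_cfg = {
    "ResNet34": {
        # Frozen, pretrained encoder on the GPU (pretrained=True mirrors the default entry above).
        "model": lambda: models.resnet34(pretrained=True).eval().to("cuda:0"),
        # [batch, height, width, 3] -> [batch, 3, height, width], then ImageNet mean/std normalization.
        "preprocess": lambda img: (
            img.permute(0, 3, 1, 2)
            - torch.tensor([0.485, 0.456, 0.406], device=img.device).view(1, 3, 1, 1)
        ) / torch.tensor([0.229, 0.224, 0.225], device=img.device).view(1, 3, 1, 1),
        # Detach so the features are returned as plain observations with no gradient history.
        "inference": lambda model, images: model(images).detach(),
    },
}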