Skip to content

Commit bf2238a

Browse files
committed
convert obs to class
1 parent d907803 commit bf2238a

File tree

2 files changed

+63
-52
lines changed

2 files changed

+63
-52
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ known_third_party = [
4747
"warp",
4848
"carb",
4949
"Semantics",
50+
"torchvision"
5051
]
5152
# Imports from this repository
5253
known_first_party = "omni.isaac.lab"

source/extensions/omni.isaac.lab/omni/isaac/lab/envs/mdp/observations.py

Lines changed: 62 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
import omni.isaac.lab.utils.math as math_utils
1818
from omni.isaac.lab.assets import Articulation, RigidObject
1919
from omni.isaac.lab.managers import SceneEntityCfg
20+
from omni.isaac.lab.managers.manager_base import ManagerTermBase
21+
from omni.isaac.lab.managers.manager_term_cfg import ObservationTermCfg
2022
from omni.isaac.lab.sensors import Camera, RayCaster, RayCasterCamera, TiledCamera
2123

2224
if TYPE_CHECKING:
@@ -233,61 +235,69 @@ def image(
233235
return images.clone()
234236

235237

236-
def image_features(
237-
env: ManagerBasedEnv,
238-
sensor_cfg: SceneEntityCfg = SceneEntityCfg("tiled_camera"),
239-
data_type: str = "rgb",
240-
convert_perspective_to_orthogonal: bool = True,
241-
model_name: str = "Theia",
242-
model_zoo_cfg: dict | None = None,
243-
) -> torch.Tensor:
244-
"""Extracted image features with a frozen encoder from Images of a specific datatype from the camera sensor.
245-
246-
Args:
247-
env: The environment the cameras are placed within.
248-
sensor_cfg: The desired sensor to read from. Defaults to SceneEntityCfg("tiled_camera").
249-
data_type: The data type to pull from the desired camera. Defaults to "rgb".
250-
model_name: The name of the model in model_zoo_cfg to use for feature extraction.
251-
model_zoo_cfg: A dictionary with string keys and callable values. Should include "model",
252-
(mapped to a callable with no arguments to return the model), "preprocess" (mapped to
253-
a callable which consumes the images and returns the preprocessed images),
254-
and "inference" (mapped to a callable that, given the model and the preprocessed images,
255-
returns the features.)
238+
class image_features(ManagerTermBase):
239+
"""Extracted image features with a frozen encoder from images of a specific datatype from the camera sensor.
256240
257-
Returns:
258-
The features from the images produced at the last timestep
241+
Calls :func:`image` to get the images, then performs inference. On initialization,
242+
for a model zoo different from the default, define model_zoo_cfg: A dictionary with string keys and callable values.
243+
Should include "model", (mapped to a callable with no arguments to return the model), "preprocess" (mapped to
244+
a callable which consumes the images and returns the preprocessed images),
245+
and "inference" (mapped to a callable that, given the model and the preprocessed images, returns the features).
259246
"""
260-
if not hasattr(image_features, "model_zoo"):
261-
image_features.model_zoo = {}
262-
263-
if model_zoo_cfg is None:
264-
model_zoo_cfg = {
265-
"ResNet18": {
266-
"model": lambda: models.resnet18(pretrained=True).eval().to("cuda:0"),
267-
"preprocess": lambda img: (
268-
img.permute(0, 3, 1, 2) # Convert [batch, height, width, 3] -> [batch, 3, height, width]
269-
- torch.tensor([0.485, 0.456, 0.406], device=img.device).view(1, 3, 1, 1)
270-
) / torch.tensor([0.229, 0.224, 0.225], device=img.device).view(1, 3, 1, 1),
271-
"inference": lambda model, images: model(images),
272-
},
273-
}
274-
275-
if model_name not in image_features.model_zoo:
276-
print(f"[INFO]: Adding {model_name} to persistent frozen feature extraction model zoo...")
277-
image_features.model_zoo[model_name] = model_zoo_cfg[model_name]["model"]()
278-
279-
images = image(
280-
env=env,
281-
sensor_cfg=sensor_cfg,
282-
data_type=data_type,
283-
convert_perspective_to_orthogonal=convert_perspective_to_orthogonal,
284-
normalize=True, # want this for training stability
285-
)
286-
287-
proc_images = model_zoo_cfg[model_name]["preprocess"](images)
288-
features = model_zoo_cfg[model_name]["inference"](image_features.model_zoo[model_name], proc_images)
289247

290-
return features
248+
def __init__(
249+
self,
250+
cfg: ObservationTermCfg,
251+
env: ManagerBasedEnv,
252+
model_zoo_cfg: dict | None = None,
253+
initialize_all: bool = False,
254+
):
255+
super().__init__(cfg, env)
256+
if model_zoo_cfg is None:
257+
self.model_zoo_cfg = {
258+
"ResNet18": {
259+
"model": lambda: models.resnet18(pretrained=True).eval().to("cuda:0"),
260+
"preprocess": lambda img: (
261+
img.permute(0, 3, 1, 2) # Convert [batch, height, width, 3] -> [batch, 3, height, width]
262+
# Normalize in the format expected by pytorch; https://pytorch.org/hub/pytorch_vision_resnet/
263+
- torch.tensor([0.485, 0.456, 0.406], device=img.device).view(1, 3, 1, 1)
264+
) / torch.tensor([0.229, 0.224, 0.225], device=img.device).view(1, 3, 1, 1),
265+
"inference": lambda model, images: model(images),
266+
},
267+
}
268+
self.reset_model(initialize_all=initialize_all)
269+
270+
# The following is named reset_model instead of reset as otherwise it's called at the end of every episode
271+
def reset_model(self, initialize_all=False):
272+
self.model_zoo = {}
273+
if initialize_all:
274+
for model_name, model_callables in self.model_zoo_cfg.items():
275+
self.model_zoo[model_name] = model_callables["model"]()
276+
277+
def __call__(
278+
self,
279+
env: ManagerBasedEnv,
280+
sensor_cfg: SceneEntityCfg = SceneEntityCfg("tiled_camera"),
281+
data_type: str = "rgb",
282+
convert_perspective_to_orthogonal: bool = False,
283+
model_name: str = "ResNet18",
284+
):
285+
if model_name not in self.model_zoo:
286+
print(f"[INFO]: Adding {model_name} to the model zoo")
287+
self.model_zoo[model_name] = self.model_zoo_cfg[model_name]["model"]()
288+
289+
images = image(
290+
env=env,
291+
sensor_cfg=sensor_cfg,
292+
data_type=data_type,
293+
convert_perspective_to_orthogonal=convert_perspective_to_orthogonal,
294+
normalize=True, # want this for training stability
295+
)
296+
297+
proc_images = self.model_zoo_cfg[model_name]["preprocess"](images)
298+
features = self.model_zoo_cfg[model_name]["inference"](self.model_zoo[model_name], proc_images)
299+
300+
return features
291301

292302

293303
"""

0 commit comments

Comments
 (0)