[BUG] R3MTransform is wrongly used in "USING PRETRAINED MODELS" tutorial #3131

@Xmaster6y

Description

Describe the bug

R3MTransform operates under @torch.no_grad(), so its parameters cannot be trained, contrary to what the tutorial implies.
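
A minimal, torch-only sketch of the mechanism (this is not torchrl code): when a module's forward runs under torch.no_grad(), its outputs are detached from its parameters, so a downstream loss trains only the head, and calling .train() on the module does not change that. `FrozenExtractor` is a hypothetical stand-in for the R3M backbone.

```python
import torch
from torch import nn


class FrozenExtractor(nn.Module):
    """Hypothetical stand-in for a feature extractor whose forward is wrapped in no_grad."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 8)

    @torch.no_grad()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


torch.manual_seed(0)
extractor = FrozenExtractor()
extractor.train()  # train mode does not re-enable gradient tracking
head = nn.Linear(8, 2)

features = extractor(torch.randn(3, 4))
out = head(features)
out.sum().backward()

print(features.requires_grad)  # False: features are detached from the extractor
print(head.weight.grad is not None)  # True: the head is trained
print(extractor.linear.weight.grad)  # None: silent failure, as in the report

# Asking for the extractor's gradient explicitly fails, like the last line of the repro.
raised = False
try:
    torch.autograd.grad(head(extractor(torch.randn(3, 4))).sum(), extractor.linear.weight)
except RuntimeError:
    raised = True  # the extractor weight is not part of the autograd graph
print(raised)  # True
```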

To Reproduce

import torch
from torch import nn
from torchrl.envs import R3MTransform, DTypeCastTransform, TransformedEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import Actor
from tensordict.nn import TensorDictSequential


base_env = GymEnv("Ant-v4", from_pixels=True)
# Only used below for its action_spec; the rollout is collected from base_env.
env = TransformedEnv(base_env, DTypeCastTransform(dtype_in=torch.uint8, dtype_out=torch.float32, in_keys=["pixels"]))

# R3M transform used as a feature extractor inside the policy.
r3m = R3MTransform("resnet50", in_keys=["pixels"], download=True)
net = nn.Sequential(
    nn.LazyLinear(128),
    nn.Tanh(),
    nn.Linear(128, env.action_spec.shape[-1]),
)
actor = Actor(net, in_keys=["r3m_vec"])
r3m.train()  # train mode does not re-enable gradients
policy = TensorDictSequential(r3m, actor)
rollout = base_env.rollout(10)

output = policy(rollout)
init_grad = torch.ones_like(output["action"])

output["action"].backward(init_grad, retain_graph=True)
for name, param in policy[0].named_parameters():
    if param.grad is not None:
        print(name)  # nothing printed (silent failure): no R3M parameter receives a gradient

torch.autograd.grad(output["action"], net[2].weight, init_grad, retain_graph=True)  # works as expected
# Raises a RuntimeError: the R3M backbone weights are not part of the autograd graph
torch.autograd.grad(output["action"], policy[0].transforms[3].convnet.conv1.weight, init_grad, retain_graph=True)

Expected behavior

I expected the transform to propagate gradients when it is used inside a policy, rather than (only) as an environment transform.
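
A minimal, torch-only sketch of the expected behavior, assuming the feature extractor's forward is not hard-wired to torch.no_grad() and the caller decides instead. `backbone` is a hypothetical stand-in for the R3M network.

```python
import torch
from torch import nn

torch.manual_seed(0)
# Hypothetical stand-in for the R3M network; its forward is NOT wrapped in torch.no_grad().
backbone = nn.Linear(4, 8)

# Environment-side preprocessing: the caller opts out of autograd explicitly.
with torch.no_grad():
    env_features = backbone(torch.randn(3, 4))

# Policy-side use: the same module, called normally, stays in the autograd graph.
policy_features = backbone(torch.randn(3, 4))
policy_features.sum().backward()

print(env_features.requires_grad)  # False
print(backbone.weight.grad is not None)  # True
```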

Screenshots

NA

System info

Describe the characteristics of your environment:

  • torchrl v0.9.2
  • Python v3.10.11

Additional context

R3MTransform also has other issues, such as incorrect type hints and non-optional delete keys.

Reason and Possible fixes

One possible fix: create an R3MModule (based on TensorDictModule) that could be used both inside the corresponding transform and directly in a policy as a feature extractor.

More generally, native support for (pretrained) feature extractors in the modules, beyond transforms, would be nice to have, e.g. as in SB3 (Stable-Baselines3).
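
A minimal, torch-only sketch of such a feature-extractor policy, loosely in that spirit. `FeatureExtractorPolicy` and `freeze_backbone` are hypothetical names, not an existing torchrl or SB3 API. Freezing is expressed with `requires_grad_(False)`, so it is an explicit, opt-in choice rather than a hard-coded torch.no_grad().

```python
import torch
from torch import nn


class FeatureExtractorPolicy(nn.Module):
    """Hypothetical policy: a (possibly pretrained) backbone followed by a trainable head."""

    def __init__(self, backbone: nn.Module, feature_dim: int, action_dim: int,
                 freeze_backbone: bool = False) -> None:
        super().__init__()
        self.backbone = backbone
        self.head = nn.Sequential(nn.Linear(feature_dim, 64), nn.Tanh(), nn.Linear(64, action_dim))
        # Freezing is opt-in: the backbone only stops tracking gradients if asked to.
        self.backbone.requires_grad_(not freeze_backbone)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return self.head(self.backbone(obs))


def backbone_gets_grad(freeze: bool) -> bool:
    """Run one backward pass and report whether every backbone parameter got a gradient."""
    backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 32), nn.ReLU())
    policy = FeatureExtractorPolicy(backbone, feature_dim=32, action_dim=4, freeze_backbone=freeze)
    policy(torch.rand(5, 3, 8, 8)).sum().backward()
    return all(p.grad is not None for p in policy.backbone.parameters())


print(backbone_gets_grad(False))  # True: the backbone is fine-tuned
print(backbone_gets_grad(True))   # False: the backbone is frozen on purpose
```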

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)

Metadata

Assignees

Labels

bug: Something isn't working

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
