Skip to content

Commit 68b3025

Browse files
committed
format
1 parent e8db087 commit 68b3025

File tree

1 file changed

+16
-7
lines changed

1 file changed

+16
-7
lines changed

dictionary_learning/dictionary.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ def __init__(self, activation_normalizer: ActivationNormalizer | None = None):
3737
super().__init__()
3838
self.activation_normalizer = activation_normalizer
3939

40-
4140
def normalize_activations(self, x: th.Tensor, inplace: bool = False) -> th.Tensor:
4241
"""
4342
Normalize input activations using the configured normalizer.
@@ -596,13 +595,18 @@ def from_pretrained(
596595
# Load activation normalizer if present in kwargs
597596
activation_normalizer_mean = state_dict.get("activation_normalizer.mean", None)
598597
activation_normalizer_std = state_dict.get("activation_normalizer.std", None)
599-
if activation_normalizer_mean is not None and activation_normalizer_std is not None:
598+
if (
599+
activation_normalizer_mean is not None
600+
and activation_normalizer_std is not None
601+
):
600602
activation_normalizer = ActivationNormalizer(
601603
mean=activation_normalizer_mean, std=activation_normalizer_std
602604
)
603-
else:
605+
else:
604606
activation_normalizer = None
605-
autoencoder = cls(activation_dim, dict_size, k, activation_normalizer=activation_normalizer)
607+
autoencoder = cls(
608+
activation_dim, dict_size, k, activation_normalizer=activation_normalizer
609+
)
606610
autoencoder.load_state_dict(state_dict)
607611
if device is not None:
608612
autoencoder.to(device)
@@ -843,7 +847,6 @@ def __init__(
843847
self.weight = nn.Parameter(weight)
844848
self.activation_normalizer = activation_normalizer
845849

846-
847850
def forward(
848851
self,
849852
f: th.Tensor,
@@ -1268,7 +1271,10 @@ def from_pretrained(
12681271
# Load activation normalizer if present in kwargs
12691272
activation_normalizer_mean = state_dict.get("activation_normalizer.mean", None)
12701273
activation_normalizer_std = state_dict.get("activation_normalizer.std", None)
1271-
if activation_normalizer_mean is not None and activation_normalizer_std is not None:
1274+
if (
1275+
activation_normalizer_mean is not None
1276+
and activation_normalizer_std is not None
1277+
):
12721278
activation_normalizer = ActivationNormalizer(
12731279
mean=activation_normalizer_mean, std=activation_normalizer_std
12741280
)
@@ -1669,7 +1675,10 @@ def from_pretrained(
16691675
# Load activation normalizer if present in kwargs
16701676
activation_normalizer_mean = state_dict.get("activation_normalizer.mean", None)
16711677
activation_normalizer_std = state_dict.get("activation_normalizer.std", None)
1672-
if activation_normalizer_mean is not None and activation_normalizer_std is not None:
1678+
if (
1679+
activation_normalizer_mean is not None
1680+
and activation_normalizer_std is not None
1681+
):
16731682
activation_normalizer = ActivationNormalizer(
16741683
mean=activation_normalizer_mean, std=activation_normalizer_std
16751684
)

0 commit comments

Comments
 (0)