@@ -37,7 +37,6 @@ def __init__(self, activation_normalizer: ActivationNormalizer | None = None):
37
37
super ().__init__ ()
38
38
self .activation_normalizer = activation_normalizer
39
39
40
-
41
40
def normalize_activations (self , x : th .Tensor , inplace : bool = False ) -> th .Tensor :
42
41
"""
43
42
Normalize input activations using the configured normalizer.
@@ -596,13 +595,18 @@ def from_pretrained(
596
595
# Load activation normalizer if present in kwargs
597
596
activation_normalizer_mean = state_dict .get ("activation_normalizer.mean" , None )
598
597
activation_normalizer_std = state_dict .get ("activation_normalizer.std" , None )
599
- if activation_normalizer_mean is not None and activation_normalizer_std is not None :
598
+ if (
599
+ activation_normalizer_mean is not None
600
+ and activation_normalizer_std is not None
601
+ ):
600
602
activation_normalizer = ActivationNormalizer (
601
603
mean = activation_normalizer_mean , std = activation_normalizer_std
602
604
)
603
- else :
605
+ else :
604
606
activation_normalizer = None
605
- autoencoder = cls (activation_dim , dict_size , k , activation_normalizer = activation_normalizer )
607
+ autoencoder = cls (
608
+ activation_dim , dict_size , k , activation_normalizer = activation_normalizer
609
+ )
606
610
autoencoder .load_state_dict (state_dict )
607
611
if device is not None :
608
612
autoencoder .to (device )
@@ -843,7 +847,6 @@ def __init__(
843
847
self .weight = nn .Parameter (weight )
844
848
self .activation_normalizer = activation_normalizer
845
849
846
-
847
850
def forward (
848
851
self ,
849
852
f : th .Tensor ,
@@ -1268,7 +1271,10 @@ def from_pretrained(
1268
1271
# Load activation normalizer if present in kwargs
1269
1272
activation_normalizer_mean = state_dict .get ("activation_normalizer.mean" , None )
1270
1273
activation_normalizer_std = state_dict .get ("activation_normalizer.std" , None )
1271
- if activation_normalizer_mean is not None and activation_normalizer_std is not None :
1274
+ if (
1275
+ activation_normalizer_mean is not None
1276
+ and activation_normalizer_std is not None
1277
+ ):
1272
1278
activation_normalizer = ActivationNormalizer (
1273
1279
mean = activation_normalizer_mean , std = activation_normalizer_std
1274
1280
)
@@ -1669,7 +1675,10 @@ def from_pretrained(
1669
1675
# Load activation normalizer if present in kwargs
1670
1676
activation_normalizer_mean = state_dict .get ("activation_normalizer.mean" , None )
1671
1677
activation_normalizer_std = state_dict .get ("activation_normalizer.std" , None )
1672
- if activation_normalizer_mean is not None and activation_normalizer_std is not None :
1678
+ if (
1679
+ activation_normalizer_mean is not None
1680
+ and activation_normalizer_std is not None
1681
+ ):
1673
1682
activation_normalizer = ActivationNormalizer (
1674
1683
mean = activation_normalizer_mean , std = activation_normalizer_std
1675
1684
)
0 commit comments