@@ -26,7 +26,15 @@ class NormalizableMixin(nn.Module):
    pass through unchanged.
    """

-    def __init__(self, activation_mean: th.Tensor | None = None, activation_std: th.Tensor | None = None, activation_shape: tuple[int, ...] | None = None):
+    def __init__(
+        self,
+        activation_mean: th.Tensor | None = None,
+        activation_std: th.Tensor | None = None,
+        activation_shape: tuple[int, ...] | None = None,
+        *,
+        keep_relative_variance: bool = True,
+        target_rms: float = 1.0,
+    ):
        """
        Initialize the normalization mixin.

@@ -36,26 +44,44 @@ def __init__(self, activation_mean: th.Tensor | None = None, activation_std: th.
            activation_std: Optional std tensor for normalization. If None,
                normalization is a no-op.
            activation_shape: Shape of the activation tensor. Required if activation_mean and activation_std are None for proper initialization and registration of the buffers.
+            keep_relative_variance: If True, applies a single global scaling so that the
+                sum of variances equals target_rms**2 while their relative magnitudes stay unchanged. If False, normalization is neuron-wise.
+            target_rms: Target RMS for input activation normalization.
        """
        super().__init__()
+        self.keep_relative_variance = keep_relative_variance
+        self.register_buffer("target_rms", th.tensor(target_rms))
        if activation_mean is not None and activation_std is not None:
            # Type assertion to help linter understand these are tensors
-            assert isinstance(activation_mean, th.Tensor), "Expected mean to be a tensor"
+            assert isinstance(
+                activation_mean, th.Tensor
+            ), "Expected mean to be a tensor"
            assert isinstance(activation_std, th.Tensor), "Expected std to be a tensor"
            assert not th.isnan(activation_mean).any(), "Expected mean to be non-NaN"
            assert not th.isnan(activation_std).any(), "Expected std to be non-NaN"
            self.register_buffer("activation_mean", activation_mean)
            self.register_buffer("activation_std", activation_std)
        else:
-            assert activation_shape is not None, "activation_shape must be provided if activation_mean and activation_std are None"
+            assert (
+                activation_shape is not None
+            ), "activation_shape must be provided if activation_mean and activation_std are None"
            self.register_buffer("activation_mean", th.nan * th.ones(activation_shape))
            self.register_buffer("activation_std", th.nan * th.ones(activation_shape))

+        if self.keep_relative_variance and self.has_activation_normalizer:
+            total_var = (self.activation_std ** 2).sum()
+            activation_global_scale = self.target_rms / th.sqrt(total_var + 1e-8)
+            self.register_buffer("activation_global_scale", activation_global_scale)
+        else:
+            self.register_buffer("activation_global_scale", th.tensor(1.0))
+
    @property
    def has_activation_normalizer(self) -> bool:
        """Check if activation normalization is enabled."""
-        return (not th.isnan(self.activation_mean).any() and
-                not th.isnan(self.activation_std).any())
+        return (
+            not th.isnan(self.activation_mean).any()
+            and not th.isnan(self.activation_std).any()
+        )

    def normalize_activations(self, x: th.Tensor, inplace: bool = False) -> th.Tensor:
        """
@@ -74,7 +100,12 @@ def normalize_activations(self, x: th.Tensor, inplace: bool = False) -> th.Tenso
            # Type assertions for linter
            assert isinstance(self.activation_mean, th.Tensor)
            assert isinstance(self.activation_std, th.Tensor)
-            return (x - self.activation_mean) / (self.activation_std + 1e-8)
+            x = x - self.activation_mean
+
+            if self.keep_relative_variance:
+                return x * self.activation_global_scale
+            else:
+                return x / (self.activation_std + 1e-8)
        return x

    def denormalize_activations(self, x: th.Tensor, inplace: bool = False) -> th.Tensor:
@@ -94,7 +125,13 @@ def denormalize_activations(self, x: th.Tensor, inplace: bool = False) -> th.Ten
            # Type assertions for linter
            assert isinstance(self.activation_mean, th.Tensor)
            assert isinstance(self.activation_std, th.Tensor)
-            return x * (self.activation_std + 1e-8) + self.activation_mean
+
+            if self.keep_relative_variance:
+                x = x / (self.activation_global_scale + 1e-8)
+            else:
+                x = x * (self.activation_std + 1e-8)
+
+            return x + self.activation_mean
        return x

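A self-contained round-trip sketch of the two methods above on the keep_relative_variance path (hypothetical statistics and data; the real code operates on the registered buffers): centering plus one global scale, inverted by the denormalizer up to the 1e-8 epsilon terms.

import torch as th

mean = th.zeros(3)                       # hypothetical per-dimension mean
std = th.tensor([0.5, 2.0, 4.0])         # hypothetical per-dimension std
x = th.randn(10, 3) * std + mean

target_rms = 1.0
scale = target_rms / th.sqrt((std ** 2).sum() + 1e-8)   # activation_global_scale

x_norm = (x - mean) * scale              # normalize_activations
x_back = x_norm / (scale + 1e-8) + mean  # denormalize_activations
assert th.allclose(x_back, x, atol=1e-4)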
@@ -454,6 +491,8 @@ def __init__(
        k: int,
        activation_mean: th.Tensor | None = None,
        activation_std: th.Tensor | None = None,
+        target_rms: float = 1.0,
+        encoder_init_norm: float = 1.0,
    ):
        """
        Initialize the Batch Top-K SAE.
@@ -464,11 +503,17 @@ def __init__(
            k: Number of top features to keep active across the batch
            activation_mean: Optional mean tensor for input activation normalization. If None, no normalization is applied.
            activation_std: Optional std tensor for input activation normalization. If None, no normalization is applied.
+            target_rms: Target RMS for input activation normalization.
+            encoder_init_norm: Scale factor applied to the tied encoder weights at initialization.
        """

-        super().__init__(activation_mean=activation_mean, activation_std=activation_std, activation_shape=(activation_dim,))
-
-
+        super().__init__(
+            activation_mean=activation_mean,
+            activation_std=activation_std,
+            activation_shape=(activation_dim,),
+            target_rms=target_rms,
+        )
+
        self.activation_dim = activation_dim
        self.dict_size = dict_size
@@ -482,7 +527,7 @@ def __init__(
        )

        self.encoder = nn.Linear(activation_dim, dict_size)
-        self.encoder.weight.data = self.decoder.weight.T.clone()
+        self.encoder.weight.data = self.decoder.weight.T.clone() * encoder_init_norm
        self.encoder.bias.data.zero_()
        self.b_dec = nn.Parameter(th.zeros(activation_dim))
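The encoder_init_norm change above only rescales the tied initialization. A hedged sketch of the effect, using a plain nn.Linear as a stand-in for the class's decoder (the value 0.1 is made up; 1.0 reproduces the previous behaviour):

import torch as th
import torch.nn as nn

activation_dim, dict_size = 16, 64
encoder_init_norm = 0.1

decoder = nn.Linear(dict_size, activation_dim, bias=False)
encoder = nn.Linear(activation_dim, dict_size)
# tie the encoder to the decoder transpose, then shrink it by encoder_init_norm
encoder.weight.data = decoder.weight.T.clone() * encoder_init_norm
encoder.bias.data.zero_()

assert th.allclose(encoder.weight, decoder.weight.T * encoder_init_norm)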
@@ -627,10 +672,10 @@ def from_pretrained(
        elif "k" in state_dict and k != state_dict["k"].item():
            raise ValueError(f"k={k} != {state_dict['k'].item()}=state_dict['k']")

-
-
        autoencoder = cls(
-            activation_dim, dict_size, k,
+            activation_dim,
+            dict_size,
+            k,
        )
        autoencoder.load_state_dict(state_dict)
        if device is not None:
@@ -645,6 +690,7 @@ def dtype(self):
    def device(self):
        return self.encoder.weight.device

+
# TODO merge this with AutoEncoder
class AutoEncoderNew(Dictionary, nn.Module):
    """
@@ -994,6 +1040,7 @@ class CrossCoder(Dictionary, NormalizableMixin):
        code_normalization_alpha_cc: Weight for CrossCoder component in MIXED normalization
        activation_mean: Optional mean tensor for input/output activation normalization
        activation_std: Optional std tensor for input/output activation normalization
+        target_rms: Optional target RMS for input/output activation normalization
    """

    def __init__(
@@ -1012,6 +1059,7 @@ def __init__(
        code_normalization_alpha_cc: float | None = 0.1,
        activation_mean: th.Tensor | None = None,
        activation_std: th.Tensor | None = None,
+        target_rms: float | None = None,
    ):
        """
        Initialize a CrossCoder sparse autoencoder.
@@ -1031,11 +1079,16 @@ def __init__(
            code_normalization_alpha_cc: Weight for CrossCoder component in MIXED normalization
            activation_mean: Optional mean tensor for input/output activation normalization
            activation_std: Optional std tensor for input/output activation normalization
+            target_rms: Optional target RMS for input/output activation normalization
        """
        # First initialize the base classes that don't take normalization parameters
-        super().__init__(activation_mean=activation_mean, activation_std=activation_std, activation_shape=(num_layers, activation_dim))
+        super().__init__(
+            activation_mean=activation_mean,
+            activation_std=activation_std,
+            activation_shape=(num_layers, activation_dim),
+            target_rms=target_rms,
+        )

-
        if num_decoder_layers is None:
            num_decoder_layers = num_layers
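Since the CrossCoder passes activation_shape=(num_layers, activation_dim) to the mixin, the statistics buffers are two-dimensional and keep_relative_variance yields a single scalar computed jointly over all layers and dimensions. A small self-contained check of that shape behaviour (all values hypothetical):

import torch as th

num_layers, activation_dim = 2, 4
activation_std = th.rand(num_layers, activation_dim) + 0.1  # hypothetical per-layer stds
target_rms = 1.0

scale = target_rms / th.sqrt((activation_std ** 2).sum() + 1e-8)
assert scale.ndim == 0  # one scalar shared by every layer and dimension
assert th.isclose(((activation_std * scale) ** 2).sum(), th.tensor(target_rms ** 2), atol=1e-4)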
@@ -1306,7 +1359,7 @@ def dtype(self):
    @property
    def device(self):
        return self.encoder.weight.device
-
+
    def resample_neurons(self, deads, activations):
        """
        Resample dead neurons by reinitializing their weights.
@@ -1401,6 +1454,7 @@ def __init__(
            norm_init_scale: Scale factor for weight initialization normalization
            activation_mean: Optional mean tensor for input/output activation normalization
            activation_std: Optional std tensor for input/output activation normalization
+            target_rms: Optional target RMS for input/output activation normalization
            *args: Additional positional arguments passed to parent class
            **kwargs: Additional keyword arguments passed to parent class
        """
@@ -1411,6 +1465,7 @@ def __init__(
            norm_init_scale=norm_init_scale,
            activation_mean=activation_mean,
            activation_std=activation_std,
+            target_rms=target_rms,
            *args,
            **kwargs,
        )
@@ -1687,7 +1742,6 @@ def from_pretrained(
        ), f"k in kwargs ({kwargs['k']}) does not match k in state_dict ({state_dict['k']})"
        kwargs.pop("k")

-
        crosscoder = cls(
            activation_dim,
            dict_size,